1 module database.postgresql.packet;
2 
3 import std.algorithm;
4 import std.bitmanip;
5 import std.traits;
6 
7 import database.postgresql.exception;
8 
/**
 * Returns a `T` whose in-memory bytes are the big-endian encoding of `x`,
 * i.e. `x` converted to network byte order. Single-byte types are returned
 * unchanged because they have no byte order.
 */
pragma(inline, true) T host(T)(T x) if (isScalarType!T)
{
    static if (T.sizeof == 1)
    {
        return x;
    }
    else
    {
        auto encoded = nativeToBigEndian(x);
        return *cast(T*) encoded.ptr;
    }
}
20 
/**
 * Treats the in-memory bytes of `x` as a big-endian encoding and returns the
 * value in host byte order. Single-byte types are returned unchanged.
 */
pragma(inline, true) T native(T)(T x) if (isScalarType!T)
{
    static if (T.sizeof == 1)
    {
        return x;
    }
    else
    {
        auto bigEndianBytes = cast(ubyte[T.sizeof]*) &x;
        return bigEndianToNative!T(*bigEndianBytes);
    }
}
32 
/**
 * Decodes a big-endian `T` from the bytes at `ptr` and returns it in host
 * byte order.
 *
 * Params:
 *   ptr = address of at least `T.sizeof` readable bytes; no alignment is
 *         required for multi-byte types (they are read as a ubyte array).
 */
pragma(inline, true) T native(T)(ubyte* ptr) if (isScalarType!T)
{
    static if (T.sizeof > 1)
    {
        return bigEndianToNative!T(*cast(ubyte[T.sizeof]*)ptr);
    }
    else
    {
        // One byte has no byte order: reinterpret it as T (ubyte, byte, char, bool).
        return *cast(T*)ptr;
    }
}
44 
/**
 * Read cursor over the payload of a received message.
 *
 * Multi-byte scalars are decoded from big-endian (network) order with
 * `native`. `eat*`/`skip*` calls consume bytes from the front of the unread
 * slice; `peek*` calls read without consuming. Truncated data or a missing
 * string terminator raises PgSQLProtocolException instead of reading past
 * the end of the buffer.
 */
struct InputPacket
{
    @disable this();

    /**
     * Params:
     *   type   = message type byte, later returned by `type()`.
     *   buffer = buffer holding the payload; its contents at construction
     *            time become the readable data.
     */
    this(ubyte type, ubyte[]* buffer)
    {
        type_ = type;
        buffer_ = buffer;
        in_ = *buffer_;
    }

    /// Returns the message type byte given to the constructor.
    ubyte type() const
    {
        return type_;
    }

    /// Decodes a big-endian `T` at the cursor without consuming it.
    /// Throws: PgSQLProtocolException if fewer than `T.sizeof` bytes remain.
    T peek(T)() if (!isArray!T)
    {
        return decodeFront!T();
    }

    /// Decodes a big-endian `T` at the cursor and consumes it.
    /// Throws: PgSQLProtocolException if fewer than `T.sizeof` bytes remain.
    T eat(T)() if (!isArray!T)
    {
        auto value = decodeFront!T();
        in_ = in_[T.sizeof .. $];
        return value;
    }

    /// Returns the NUL-terminated string at the cursor (terminator excluded)
    /// without consuming it. The result aliases the packet buffer.
    /// Throws: PgSQLProtocolException if no terminator remains.
    const(char)[] peekz()
    {
        return cast(const(char)[]) in_[0 .. stringLength()];
    }

    /// Returns the NUL-terminated string at the cursor (terminator excluded)
    /// and consumes it together with its terminator.
    /// Throws: PgSQLProtocolException if no terminator remains.
    const(char)[] eatz()
    {
        auto len = stringLength();
        auto result = cast(const(char)[]) in_[0 .. len];
        in_ = in_[len + 1 .. $];
        return result;
    }

    /// Consumes a NUL-terminated string and its terminator.
    /// Throws: PgSQLProtocolException if no terminator remains.
    void skipz()
    {
        in_ = in_[stringLength() + 1 .. $];
    }

    /// Consumes `count` elements of `T`'s element type and returns them as a
    /// slice aliasing the buffer. Elements are NOT byte-swapped.
    /// Throws: PgSQLProtocolException if fewer than `count` elements remain.
    T eat(T)(size_t count) if (isArray!T)
    {
        alias ValueType = typeof(T.init[0]);

        // Divide instead of multiplying so a huge count cannot overflow.
        if (count > in_.length / ValueType.sizeof)
            throw new PgSQLProtocolException("Bad packet format");
        auto ptr = cast(ValueType*) in_.ptr;
        in_ = in_[ValueType.sizeof * count .. $];
        return ptr[0 .. count];
    }

    /// Consumes a `T` and throws PgSQLProtocolException unless it equals `x`.
    void expect(T)(T x)
    {
        if (x != eat!T)
            throw new PgSQLProtocolException("Bad packet format");
    }

    /// Consumes `count` raw bytes.
    /// Throws: PgSQLProtocolException if fewer than `count` bytes remain.
    void skip(size_t count)
    {
        ensureAvailable(count);
        in_ = in_[count .. $];
    }

    /// Returns the index of the first byte equal to `x` in the unread data,
    /// or -1 if it is absent. Nothing is consumed.
    /// Throws: PgSQLProtocolException if `expect` is set and `x` is absent.
    auto countUntil(ubyte x, bool expect)
    {
        // UFCS call: member functions are not UFCS candidates, so this
        // resolves to std.algorithm's countUntil rather than recursing.
        auto index = in_.countUntil(x);
        if (expect && index < 0)
            throw new PgSQLProtocolException("Bad packet format");
        return index;
    }

    /// Skips a length-encoded integer: a header byte below 0xfb is the value
    /// itself; 0xfb carries no payload; 0xfc/0xfd/0xfe are followed by
    /// 2/3/8 bytes. Any other header throws PgSQLProtocolException.
    // NOTE(review): this 0xfb..0xfe scheme resembles MySQL's length-encoded
    // integers; confirm the PostgreSQL code paths actually need it.
    void skipLenEnc()
    {
        auto header = eat!ubyte;
        if (header >= 0xfb)
        {
            switch(header)
            {
                case 0xfb:
                    return;
                case 0xfc:
                    skip(2);
                    return;
                case 0xfd:
                    skip(3);
                    return;
                case 0xfe:
                    skip(8);
                    return;
                default:
                    throw new PgSQLProtocolException("Bad packet format");
            }
        }
    }

    /// Reads a length-encoded integer (same header scheme as `skipLenEnc`;
    /// 0xfb decodes as 0). Multi-byte parts are read with `eat` (big-endian)
    /// and combined low part first.
    /// Throws: PgSQLProtocolException on an invalid header byte.
    ulong eatLenEnc()
    {
        auto header = eat!ubyte;
        if (header < 0xfb)
            return header;

        ulong lo;
        ulong hi;

        switch(header)
        {
            case 0xfb:
                return 0;
            case 0xfc:
                return eat!ushort;
            case 0xfd:
                lo = eat!ubyte;
                hi = eat!ushort;
                return lo | (hi << 8);
            case 0xfe:
                lo = eat!uint;
                hi = eat!uint;
                return lo | (hi << 32);
            default:
                throw new PgSQLProtocolException("Bad packet format");
        }
    }

    /// Returns the unread remainder of the payload (aliases the buffer).
    auto get() const
    {
        return in_;
    }

    /// Returns the number of unread bytes.
    auto remaining() const
    {
        return in_.length;
    }

    /// True when every byte has been consumed.
    bool empty() const
    {
        return in_.length == 0;
    }

    // Throws unless at least `count` unread bytes remain. Payload contents are
    // not trusted, so a bad length must not become an out-of-bounds read.
    private void ensureAvailable(size_t count)
    {
        if (count > in_.length)
            throw new PgSQLProtocolException("Bad packet format");
    }

    // Decodes a big-endian T at the cursor without consuming it. The bytes are
    // copied into a local first so no misaligned T* is ever dereferenced.
    private T decodeFront(T)()
    {
        ensureAvailable(T.sizeof);
        T raw = void;
        (cast(ubyte*) &raw)[0 .. T.sizeof] = in_[0 .. T.sizeof];
        return native(raw);
    }

    // Length of the NUL-terminated string at the cursor, terminator excluded.
    // Searches only within the unread slice, unlike strlen.
    private size_t stringLength()
    {
        foreach (i, b; in_)
        {
            if (b == 0)
                return i;
        }
        throw new PgSQLProtocolException("Bad packet format");
    }

protected:

    ubyte[]* buffer_;   // buffer passed to the constructor
    ubyte[] in_;        // unread remainder of the payload
    ubyte type_;        // message type byte
}
203 
/**
 * Builder for an outgoing message in a caller-owned buffer.
 *
 * The buffer starts with `implicit_` header bytes: an optional type byte
 * followed by a 4-byte big-endian length field written by `finalize`. Payload
 * bytes follow the header; offsets passed to `put`/`marker` are relative to
 * the start of the payload. The buffer is resized on demand, which may move
 * it, so slices obtained earlier may be invalidated by later writes.
 */
struct OutputPacket
{
    @disable this();

    /// Starts a message without a type byte: only the 4-byte length field
    /// precedes the payload.
    this(ubyte[]* buffer)
    {
        buffer_ = buffer;
        implicit_ = 4;
        // Guarantee the header exists so finalize() never writes through a
        // null or too-short buffer, even when no payload is added.
        if ((*buffer_).length < implicit_)
            (*buffer_).length = implicit_;
        out_ = (*buffer_).ptr + implicit_;
    }

    /// Starts a message whose first byte is `type`, followed by the 4-byte
    /// length field and then the payload.
    this(ubyte type, ubyte[]* buffer)
    {
        buffer_ = buffer;
        implicit_ = 5;
        if ((*buffer_).length < implicit_)
            (*buffer_).length = implicit_;
        (*buffer_)[0] = type;
        out_ = (*buffer_).ptr + implicit_;
    }

    /// Appends the characters of `x` followed by a NUL terminator.
    void putz(const(char)[] x)
    {
        put(x);
        put!ubyte(0);
    }

    /// Appends `x` at the current end of the payload.
    pragma(inline, true) void put(T)(T x)
    {
        put(offset_, x);
    }

    /// Writes scalar `x` in big-endian order at payload offset `offset`,
    /// growing the buffer as needed. The payload end advances only if the
    /// write extends past it.
    pragma(inline, true) void put(T)(size_t offset, T x) if (!isArray!T)
    {
        grow(offset, T.sizeof);

        storeRaw(offset, host(x));
        offset_ = max(offset + T.sizeof, offset_);
    }

    /// Writes the elements of `x` at payload offset `offset`. Elements wider
    /// than one byte are converted to big-endian individually.
    void put(T)(size_t offset, T x) if (isArray!T)
    {
        alias ValueType = Unqual!(typeof(T.init[0]));

        grow(offset, ValueType.sizeof * x.length);

        static if (ValueType.sizeof == 1)
        {
            (cast(ValueType*)(out_ + offset))[0..x.length] = x;
        }
        else
        {
            foreach (i, ref y; x)
                storeRaw(offset + i * ValueType.sizeof, host(y));
        }
        offset_ = max(offset + (ValueType.sizeof * x.length), offset_);
    }

    /// Appends `x` as a length-encoded integer: a single byte below 0xfb,
    /// otherwise a 0xfc/0xfd/0xfe header followed by 2/3/8 bytes, written
    /// low part first with each part in big-endian order.
    void putLenEnc(ulong x)
    {
        if (x < 0xfb)
        {
            put!ubyte(cast(ubyte)x);
        }
        else if (x <= ushort.max)
        {
            put!ubyte(0xfc);
            put!ushort(cast(ushort)x);
        }
        else if (x <= (uint.max >> 8))
        {
            put!ubyte(0xfd);
            put!ubyte(cast(ubyte)(x));
            put!ushort(cast(ushort)(x >> 8));
        }
        else
        {
            put!ubyte(0xfe);
            put!uint(cast(uint)x);
            put!uint(cast(uint)(x >> 32));
        }
    }

    /// Reserves room for one `T` at the end of the payload and returns its
    /// offset, so the value can be written later with `put(offset, value)`.
    size_t marker(T)() if (!isArray!T)
    {
        grow(offset_, T.sizeof);

        auto place = offset_;
        offset_ += T.sizeof;
        return place;
    }

    /// Reserves room for `count` elements of `T`'s element type at the end
    /// of the payload and returns the offset of the first one.
    size_t marker(T)(size_t count) if (isArray!T)
    {
        alias ValueType = Unqual!(typeof(T.init[0]));
        grow(offset_, ValueType.sizeof * count);

        auto place = offset_;
        offset_ += (ValueType.sizeof * count);
        return place;
    }

    /// Writes the big-endian length field, which counts itself (4 bytes)
    /// plus the payload but not a leading type byte.
    /// Throws: PgSQLConnectionException if the message is too large.
    void finalize()
    {
        if ((offset_ + implicit_) > int.max)
            throw new PgSQLConnectionException("Packet size exceeds 2^31");
        // Copy bytes instead of storing a uint: with a type byte the field
        // starts at offset 1 and would be misaligned.
        const ubyte[4] lengthField = nativeToBigEndian(cast(uint)(4 + offset_));
        (*buffer_)[implicit_ - 4 .. implicit_] = lengthField[];
    }

    /// Same as `finalize()`; the argument is ignored.
    void finalize(ubyte)
    {
        finalize();
    }

    /// Same as `finalize()`; both arguments are ignored.
    void finalize(ubyte seq, size_t extra)
    {
        finalize();
    }

    /// Discards the payload written so far (the header bytes are kept).
    void reset()
    {
        offset_ = 0;
    }

    /// Ensures the buffer can hold at least `size` payload bytes.
    void reserve(size_t size)
    {
        (*buffer_).length = max((*buffer_).length, implicit_ + size);
        out_ = (*buffer_).ptr + implicit_;
    }

    /// Appends `size` copies of the byte `x`.
    void fill(ubyte x, size_t size)
    {
        grow(offset_, size);
        out_[offset_..offset_ + size] = x;
        offset_ += size;
    }

    /// Returns the number of payload bytes written (header excluded).
    size_t length() const
    {
        return offset_;
    }

    /// True when no payload bytes have been written.
    bool empty() const
    {
        return offset_ == 0;
    }

    /// Returns the header followed by the payload written so far.
    const(ubyte)[] get() const
    {
        return (*buffer_)[0..implicit_ + offset_];
    }

    // Copies the raw bytes of `value` to payload offset `offset`. A byte copy
    // is used instead of a typed store because the destination may be
    // misaligned for V. The caller must already have called grow().
    private void storeRaw(V)(size_t offset, V value)
    {
        (out_ + offset)[0 .. V.sizeof] = (cast(const(ubyte)*) &value)[0 .. V.sizeof];
    }

protected:

    // Ensures the buffer covers the header plus payload bytes
    // [offset, offset + size), growing by doubling from at least 128 bytes.
    // Refreshes out_ because resizing may move the buffer.
    void grow(size_t offset, size_t size)
    {
        auto requested = implicit_ + offset + size;
        if (requested > (*buffer_).length)
        {
            auto capacity = max(128, (*buffer_).capacity);
            while (capacity < requested)
                capacity <<= 1;
            (*buffer_).length = capacity;
            out_ = (*buffer_).ptr + implicit_;
        }
    }

    ubyte[]* buffer_;   // caller-owned buffer holding header + payload

    ubyte* out_;        // start of the payload within *buffer_
    size_t offset_;     // payload bytes written so far (end of the payload)
    size_t implicit_;   // header size: 4, or 5 with a leading type byte
}