1 module database.mysql.packet;
2 
3 import std.algorithm;
4 import std.traits;
5 
6 import database.mysql.exception;
7 
/// Read cursor over a received MySQL packet payload.
///
/// Scalar and array reads reinterpret the underlying bytes in place
/// (`cast(T*)`), i.e. in host byte order and without copying.
/// NOTE(review): the MySQL wire format is little-endian, so this assumes a
/// little-endian host — confirm no big-endian targets are supported.
/// Bounds are checked with `assert` only, so they are not enforced in
/// release builds.
struct InputPacket
{
    @disable this();

    /// Wraps `*buffer`; the cursor starts at the first byte of that slice.
    this(ubyte[]* buffer)
    {
        buffer_ = buffer;
        in_ = *buffer_;
    }

    /// Returns the next scalar `T` without advancing the cursor.
    T peek(T)() if (!isArray!T)
    {
        assert(T.sizeof <= in_.length);
        return *(cast(T*)in_.ptr);
    }

    /// Returns the next scalar `T` and advances the cursor past it.
    T eat(T)() if (!isArray!T)
    {
        assert(T.sizeof <= in_.length);
        auto ptr = cast(T*)in_.ptr;
        in_ = in_[T.sizeof..$];
        return *ptr;
    }

    /// Returns the next `count` elements of `T`'s element type without
    /// advancing. The returned slice aliases the packet buffer (no copy).
    T peek(T)(size_t count) if (isArray!T)
    {
        // Previously `typeof(Type.init[0])`: `Type` is not declared anywhere,
        // so every instantiation of this overload failed to compile.
        alias ValueType = typeof(T.init[0]);

        assert(ValueType.sizeof * count <= in_.length);
        auto ptr = cast(ValueType*)in_.ptr;
        return ptr[0..count];
    }

    /// Returns the next `count` elements of `T`'s element type and advances
    /// past them. The returned slice aliases the packet buffer (no copy).
    T eat(T)(size_t count) if (isArray!T)
    {
        alias ValueType = typeof(T.init[0]);

        assert(ValueType.sizeof * count <= in_.length);
        auto ptr = cast(ValueType*)in_.ptr;
        in_ = in_[ValueType.sizeof * count..$];
        return ptr[0..count];
    }

    /// Consumes a `T` and checks that it equals `x`.
    /// Throws: MySQLProtocolException on mismatch.
    void expect(T)(T x)
    {
        if (x != eat!T)
            throw new MySQLProtocolException("Bad packet format");
    }

    /// Advances the cursor by `count` bytes.
    /// If nothing remains, returns immediately without checking `count`.
    void skip(size_t count)
    {
        if (in_.length == 0) return;
        assert(count <= in_.length);
        in_ = in_[count..$];
    }

    /// Returns the index of the first byte equal to `x` in the unread input,
    /// or -1 if it is absent. Does not advance the cursor.
    /// Throws: MySQLProtocolException if `expect` is true and `x` is absent.
    auto countUntil(ubyte x, bool expect)
    {
        // UFCS resolves to std.algorithm's countUntil (member functions are
        // not UFCS candidates).
        auto index = in_.countUntil(x);

        // A non-negative result always points at a byte equal to `x`, so
        // absence is the only failure case to check.
        if (expect && index < 0)
            throw new MySQLProtocolException("Bad packet format");

        return index;
    }

    /// Skips a length-encoded integer.
    ///
    /// Header < 0xfb: value is the header itself (nothing more to skip).
    /// 0xfb: no payload. 0xfc / 0xfd / 0xfe: skip 2 / 3 / 8 payload bytes.
    /// Throws: MySQLProtocolException for header 0xff.
    void skipLenEnc()
    {
        auto header = eat!ubyte;
        if (header >= 0xfb)
        {
            switch(header)
            {
                case 0xfb:
                    return;
                case 0xfc:
                    skip(2);
                    return;
                case 0xfd:
                    skip(3);
                    return;
                case 0xfe:
                    skip(8);
                    return;
                default:
                    throw new MySQLProtocolException("Bad packet format");
            }
        }
    }

    /// Reads a length-encoded integer.
    ///
    /// Header < 0xfb is the value itself; 0xfc / 0xfd / 0xfe are followed by
    /// a 2 / 3 / 8 byte little-endian value. Header 0xfb (which MySQL uses as
    /// a NULL marker) yields 0, indistinguishable here from a real zero.
    /// Throws: MySQLProtocolException for header 0xff.
    ulong eatLenEnc()
    {
        auto header = eat!ubyte;
        if (header < 0xfb)
            return header;

        ulong lo;
        ulong hi;

        switch(header)
        {
            case 0xfb:
                return 0;
            case 0xfc:
                return eat!ushort;
            case 0xfd:
                // 3-byte value: low byte, then the upper 16 bits.
                lo = eat!ubyte;
                hi = eat!ushort;
                return lo | (hi << 8);
            case 0xfe:
                lo = eat!uint;
                hi = eat!uint;
                return lo | (hi << 32);
            default:
                throw new MySQLProtocolException("Bad packet format");
        }
    }

    /// Number of unread bytes.
    auto remaining() const
    {
        return in_.length;
    }

    /// True when every byte has been consumed.
    bool empty() const
    {
        return in_.length == 0;
    }

protected:

    ubyte[]* buffer_; // the wrapped buffer (only read in the constructor)
    ubyte[] in_;      // unread remainder of *buffer_
}
144 
/// Write cursor that builds a MySQL packet in `*buffer`.
///
/// The first 4 bytes of the buffer are reserved for the packet header, which
/// `finalize` writes; every payload offset used by this struct is relative to
/// byte 4. Scalars are stored by reinterpreting the buffer (`cast(T*)`), i.e.
/// in host byte order.
/// NOTE(review): the MySQL wire format is little-endian, so multi-byte `put`s
/// assume a little-endian host — confirm no big-endian targets are supported.
/// `out_` is a raw pointer into `*buffer_`; resizing `*buffer_` from outside
/// this struct (other than via `reserve`) would leave it dangling.
struct OutputPacket
{
    @disable this();

    /// Wraps `*buffer`. The buffer is not resized here; writes grow it on
    /// demand.
    this(ubyte[]* buffer)
    {
        buffer_ = buffer;
        out_ = (*buffer_).ptr + 4;
    }

    /// Appends `x` at the current end of the payload.
    pragma(inline, true) void put(T)(T x)
    {
        put(offset_, x);
    }

    /// Writes scalar `x` at payload `offset`, growing the buffer as needed.
    /// The payload length only ever increases, so overwriting earlier bytes
    /// (e.g. at a `marker`) does not truncate the packet.
    void put(T)(size_t offset, T x) if (!isArray!T)
    {
        grow(offset, T.sizeof);

        *(cast(T*)(out_ + offset)) = x;
        offset_ = max(offset + T.sizeof, offset_);
    }

    /// Copies the elements of array `x` to payload `offset`, growing the
    /// buffer as needed. The payload length only ever increases.
    void put(T)(size_t offset, T x) if (isArray!T)
    {
        alias ValueType = Unqual!(typeof(T.init[0]));
        immutable bytes = ValueType.sizeof * x.length;

        grow(offset, bytes);

        (cast(ValueType*)(out_ + offset))[0..x.length] = x;
        offset_ = max(offset + bytes, offset_);
    }

    /// Appends `x` as a length-encoded integer: one byte below 0xfb,
    /// otherwise a 0xfc / 0xfd / 0xfe prefix followed by a 2 / 3 / 8 byte
    /// value.
    void putLenEnc(ulong x)
    {
        if (x < 0xfb)
        {
            put!ubyte(cast(ubyte)x);
        }
        else if (x <= ushort.max)
        {
            put!ubyte(0xfc);
            put!ushort(cast(ushort)x);
        }
        else if (x <= (uint.max >> 8))
        {
            // 3-byte value: low byte, then the upper 16 bits.
            put!ubyte(0xfd);
            put!ubyte(cast(ubyte)(x));
            put!ushort(cast(ushort)(x >> 8));
        }
        else
        {
            put!ubyte(0xfe);
            put!uint(cast(uint)x);
            put!uint(cast(uint)(x >> 32));
        }
    }

    /// Reserves room for one `T` at the end of the payload and returns its
    /// payload offset, so the value can be filled in later with
    /// `put(offset, value)`.
    size_t marker(T)() if (!isArray!T)
    {
        grow(offset_, T.sizeof);

        auto place = offset_;
        offset_ += T.sizeof;
        return place;
    }

    /// Reserves room for `count` elements of array type `T` at the end of the
    /// payload and returns the payload offset of the reserved space.
    size_t marker(T)(size_t count) if (isArray!T)
    {
        alias ValueType = Unqual!(typeof(T.init[0]));
        // Previously used `x.length`, but no `x` is in scope here, so every
        // instantiation of this overload failed to compile.
        immutable bytes = ValueType.sizeof * count;
        grow(offset_, bytes);

        auto place = offset_;
        offset_ += bytes;
        return place;
    }

    /// Writes the 4-byte header for the current payload with sequence id
    /// `seq`.
    /// Throws: MySQLConnectionException if the payload is 0xffffff bytes or
    /// more.
    void finalize(ubyte seq)
    {
        finalize(seq, 0);
    }

    /// Writes the 4-byte header (3-byte little-endian length, then `seq`)
    /// for a payload of `length + extra` bytes; `extra` counts bytes that are
    /// not stored in this buffer.
    /// Throws: MySQLConnectionException if that total is 0xffffff or more.
    void finalize(ubyte seq, size_t extra)
    {
        if (offset_ + extra >= 0xffffff)
            throw new MySQLConnectionException("Packet size exceeds 2^24");

        // Ensure the header bytes exist even when nothing has been written
        // yet; with an empty buffer the write below would otherwise go
        // through a null pointer.
        grow(offset_, 0);

        immutable length = cast(uint)(offset_ + extra);
        // Written byte by byte so the header layout does not depend on the
        // host's endianness.
        auto header = (*buffer_)[0..4];
        header[0] = cast(ubyte)(length);
        header[1] = cast(ubyte)(length >> 8);
        header[2] = cast(ubyte)(length >> 16);
        header[3] = seq;
    }

    /// Discards the payload (the buffer keeps its size).
    void reset()
    {
        offset_ = 0;
    }

    /// Ensures the buffer can hold the header plus `size` payload bytes.
    void reserve(size_t size)
    {
        (*buffer_).length = max((*buffer_).length, 4 + size);
        out_ = (*buffer_).ptr + 4;
    }

    /// Appends `size` copies of the byte `x`.
    void fill(ubyte x, size_t size)
    {
        grow(offset_, size);
        // Previously hard-coded to 0, silently ignoring `x`.
        out_[offset_..offset_ + size] = x;
        offset_ += size;
    }

    /// Payload length in bytes (the header is not counted).
    size_t length() const
    {
        return offset_;
    }

    /// True when no payload bytes have been written.
    bool empty() const
    {
        return offset_ == 0;
    }

    /// Header plus payload, as a slice of the wrapped buffer.
    /// Call `finalize` first so the header bytes are meaningful.
    const(ubyte)[] get() const
    {
        return (*buffer_)[0..4 + offset_];
    }

protected:

    /// Ensures the buffer covers the header plus payload bytes
    /// `[offset, offset + size)`, and re-derives `out_` if the buffer was
    /// resized.
    void grow(size_t offset, size_t size)
    {
        auto requested = 4 + offset + size;

        if (requested > (*buffer_).length)
        {
            // Start from at least 128 bytes (or the slice's existing GC
            // capacity) and double until the request fits.
            auto capacity = max(128, (*buffer_).capacity);
            while (capacity < requested)
            {
                capacity <<= 1;
            }

            (*buffer_).length = capacity;
            out_ = (*buffer_).ptr + 4;
        }
    }

    ubyte[]* buffer_; // wrapped buffer: 4 header bytes followed by the payload
    ubyte* out_;      // start of the payload, i.e. (*buffer_).ptr + 4
    size_t offset_;   // current payload length in bytes
}