module database.mysql.packet;

import std.algorithm;
import std.traits;

import database.mysql.exception;

/**
 * Read cursor over a received packet payload.
 *
 * The cursor starts at the beginning of the slice stored in `*buffer` and
 * advances as values are consumed with `eat`, `skip`, `skipLenEnc`, etc.
 * Scalars are reinterpreted directly from the raw bytes via pointer casts
 * (no byte swapping), i.e. in host byte order.
 */
struct InputPacket
{
    @disable this();

    /// Creates a cursor over the slice currently stored in `*buffer`.
    this(ubyte[]* buffer)
    {
        buffer_ = buffer;
        in_ = *buffer_;
    }

    /// Returns the next scalar of type `T` without consuming it.
    T peek(T)() if (!isArray!T)
    {
        assert(T.sizeof <= in_.length);
        return *(cast(T*)in_.ptr);
    }

    /// Returns the next scalar of type `T` and advances past it.
    T eat(T)() if (!isArray!T)
    {
        assert(T.sizeof <= in_.length);
        auto ptr = cast(T*)in_.ptr;
        in_ = in_[T.sizeof..$];
        return *ptr;
    }

    /**
     * Returns the next `count` elements of `T`'s element type without
     * consuming them. The returned slice aliases the packet buffer.
     */
    T peek(T)(size_t count) if (isArray!T)
    {
        alias ValueType = typeof(T.init[0]);

        assert(ValueType.sizeof * count <= in_.length);
        auto ptr = cast(ValueType*)in_.ptr;
        return ptr[0..count];
    }

    /**
     * Returns the next `count` elements of `T`'s element type and advances
     * past them. The returned slice aliases the packet buffer (no copy).
     */
    T eat(T)(size_t count) if (isArray!T)
    {
        alias ValueType = typeof(T.init[0]);

        assert(ValueType.sizeof * count <= in_.length);
        auto ptr = cast(ValueType*)in_.ptr;
        in_ = in_[ValueType.sizeof * count..$];
        return ptr[0..count];
    }

    /**
     * Consumes a value of type `T` and checks that it equals `x`.
     *
     * Throws: MySQLProtocolException if the consumed value differs from `x`.
     */
    void expect(T)(T x)
    {
        if (x != eat!T)
            throw new MySQLProtocolException("Bad packet format");
    }

    /// Advances `count` bytes; does nothing if the packet is already exhausted.
    void skip(size_t count)
    {
        if (in_.length == 0) return;
        assert(count <= in_.length);
        in_ = in_[count..$];
    }

    /**
     * Returns the index of the first byte equal to `x` in the unread part of
     * the packet, or a negative value if there is none. Does not advance.
     *
     * Throws: MySQLProtocolException if `expect` is true and `x` is absent.
     */
    auto countUntil(ubyte x, bool expect)
    {
        // Resolves to std.algorithm.countUntil via UFCS; it returns -1 when
        // the needle is not found (a found index always points at `x`).
        auto index = in_.countUntil(x);

        if (expect && (index < 0))
            throw new MySQLProtocolException("Bad packet format");

        return index;
    }

    /**
     * Skips a length-encoded integer: a one-byte header, followed by 2, 3 or
     * 8 more bytes for headers 0xfc, 0xfd and 0xfe respectively. Headers
     * below 0xfb, and 0xfb itself, carry no further bytes.
     *
     * Throws: MySQLProtocolException for header 0xff.
     */
    void skipLenEnc()
    {
        auto header = eat!ubyte;
        if (header >= 0xfb)
        {
            switch(header)
            {
            case 0xfb:
                return;
            case 0xfc:
                skip(2);
                return;
            case 0xfd:
                skip(3);
                return;
            case 0xfe:
                skip(8);
                return;
            default:
                throw new MySQLProtocolException("Bad packet format");
            }
        }
    }

    /**
     * Reads a length-encoded integer (see `skipLenEnc` for the layout).
     *
     * Returns: the decoded value; header 0xfb yields 0.
     * Throws: MySQLProtocolException for header 0xff.
     */
    ulong eatLenEnc()
    {
        auto header = eat!ubyte;
        if (header < 0xfb)
            return header;

        ulong lo;
        ulong hi;

        switch(header)
        {
        case 0xfb:
            return 0;
        case 0xfc:
            return eat!ushort;
        case 0xfd:
            // 3-byte value: low byte first, then the upper 16 bits.
            lo = eat!ubyte;
            hi = eat!ushort;
            return lo | (hi << 8);
        case 0xfe:
            // 8-byte value read as two 32-bit halves, low half first.
            lo = eat!uint;
            hi = eat!uint;
            return lo | (hi << 32);
        default:
            throw new MySQLProtocolException("Bad packet format");
        }
    }

    /// Number of unread bytes.
    auto remaining() const
    {
        return in_.length;
    }

    /// True when every byte has been consumed.
    bool empty() const
    {
        return in_.length == 0;
    }

protected:

    ubyte[]* buffer_;
    ubyte[] in_;
}

/**
 * Builder for an outgoing packet stored in `*buffer`.
 *
 * The first 4 bytes of the buffer are reserved for the header written by
 * `finalize` (payload length in the low 24 bits, sequence number in the high
 * 8 bits); payload offsets passed to `put` are relative to byte 4. Values are
 * stored via pointer casts, i.e. in host byte order.
 * NOTE(review): the wire format this targets is presumably little-endian, so
 * this assumes a little-endian host — confirm if other targets matter.
 */
struct OutputPacket
{
    @disable this();

    /// Creates a builder that writes into `*buffer`, which it may resize.
    this(ubyte[]* buffer)
    {
        buffer_ = buffer;
        out_ = (*buffer_).ptr + 4;
    }

    /// Appends `x` at the current end of the payload.
    pragma(inline, true) void put(T)(T x)
    {
        put(offset_, x);
    }

    /**
     * Writes scalar `x` at payload offset `offset`, growing the buffer when
     * needed. The payload length becomes at least `offset + T.sizeof`.
     */
    void put(T)(size_t offset, T x) if (!isArray!T)
    {
        grow(offset, T.sizeof);

        *(cast(T*)(out_ + offset)) = x;
        offset_ = max(offset + T.sizeof, offset_);
    }

    /**
     * Copies every element of `x` to payload offset `offset`, growing the
     * buffer when needed. The payload length grows to cover the copied bytes.
     */
    void put(T)(size_t offset, T x) if (isArray!T)
    {
        alias ValueType = Unqual!(typeof(T.init[0]));

        grow(offset, ValueType.sizeof * x.length);

        (cast(ValueType*)(out_ + offset))[0..x.length] = x;
        offset_ = max(offset + (ValueType.sizeof * x.length), offset_);
    }

    /// Appends `x` as a length-encoded integer (1, 3, 4 or 9 bytes).
    void putLenEnc(ulong x)
    {
        if (x < 0xfb)
        {
            put!ubyte(cast(ubyte)x);
        }
        else if (x <= ushort.max)
        {
            put!ubyte(0xfc);
            put!ushort(cast(ushort)x);
        }
        else if (x <= (uint.max >> 8))
        {
            // 3-byte form: low byte, then the upper 16 bits.
            put!ubyte(0xfd);
            put!ubyte(cast(ubyte)(x));
            put!ushort(cast(ushort)(x >> 8));
        }
        else
        {
            put!ubyte(0xfe);
            put!uint(cast(uint)x);
            put!uint(cast(uint)(x >> 32));
        }
    }

    /**
     * Reserves room for one `T` at the end of the payload and returns its
     * offset, so the value can be written later with `put(offset, value)`.
     */
    size_t marker(T)() if (!isArray!T)
    {
        grow(offset_, T.sizeof);

        auto place = offset_;
        offset_ += T.sizeof;
        return place;
    }

    /**
     * Reserves room for `count` elements of `T`'s element type at the end of
     * the payload and returns the offset of the first one.
     */
    size_t marker(T)(size_t count) if (isArray!T)
    {
        alias ValueType = Unqual!(typeof(T.init[0]));
        immutable bytes = ValueType.sizeof * count;
        grow(offset_, bytes);

        auto place = offset_;
        offset_ += bytes;
        return place;
    }

    /**
     * Writes the 4-byte header for the current payload length and `seq`.
     *
     * Throws: MySQLConnectionException if the length is 0xffffff or more.
     */
    void finalize(ubyte seq)
    {
        finalize(seq, 0);
    }

    /**
     * Writes the 4-byte header for a payload of `length() + extra` bytes,
     * where `extra` bytes are presumably sent separately by the caller.
     *
     * Throws: MySQLConnectionException if the total length is 0xffffff or more.
     */
    void finalize(ubyte seq, size_t extra)
    {
        if (offset_ + extra >= 0xffffff)
            throw new MySQLConnectionException("Packet size exceeds 2^24");

        // Ensure the header slot exists even when no payload was written yet.
        grow(offset_, 0);

        immutable uint length = cast(uint)(offset_ + extra);
        immutable uint header = (length & 0xffffff) | (cast(uint)seq << 24);
        *(cast(uint*)(*buffer_).ptr) = header;
    }

    /// Discards the payload (the buffer itself is kept for reuse).
    void reset()
    {
        offset_ = 0;
    }

    /// Ensures the buffer can hold at least `size` payload bytes.
    void reserve(size_t size)
    {
        (*buffer_).length = max((*buffer_).length, 4 + size);
        out_ = (*buffer_).ptr + 4;
    }

    /// Appends `size` copies of the byte `x`.
    void fill(ubyte x, size_t size)
    {
        grow(offset_, size);
        out_[offset_..offset_ + size] = x;
        offset_ += size;
    }

    /// Current payload length in bytes (excluding the 4-byte header).
    size_t length() const
    {
        return offset_;
    }

    /// True when no payload has been written.
    bool empty() const
    {
        return offset_ == 0;
    }

    /// Header plus payload, aliasing the underlying buffer.
    const(ubyte)[] get() const
    {
        return (*buffer_)[0..4 + offset_];
    }

protected:

    /// Ensures the buffer holds the header plus `offset + size` payload bytes.
    void grow(size_t offset, size_t size)
    {
        auto requested = 4 + offset + size;

        if (requested > (*buffer_).length)
        {
            // Double from max(128, current capacity) until the request fits.
            auto capacity = max(128, (*buffer_).capacity);
            while (capacity < requested)
            {
                capacity <<= 1;
            }

            (*buffer_).length = capacity;
            // Resizing may reallocate; refresh the payload pointer.
            out_ = (*buffer_).ptr + 4;
        }
    }

    ubyte[]* buffer_;
    ubyte* out_;
    size_t offset_;
}