1 module database.mysql.connection;
2 
3 import std.algorithm;
4 import std.array;
5 import std.conv : to;
6 import std.string;
7 import std.traits;
8 import std.uni : sicmp;
9 import std.utf : decode, UseReplacementDchar;
10 import std.datetime;
11 
12 import database.mysql.exception;
13 import database.mysql.packet;
14 import database.mysql.protocol;
15 import database.mysql.type;
16 import database.mysql.socket;
17 import database.mysql.row;
18 import database.mysql.appender;
19 
/// Capability flags requested by a connection when the caller supplies none.
/// During the handshake this set is intersected with the capabilities the
/// server announces to obtain the flags actually in effect.
immutable CapabilityFlags DefaultClientCaps =
    CapabilityFlags.CLIENT_LONG_PASSWORD
    | CapabilityFlags.CLIENT_LONG_FLAG
    | CapabilityFlags.CLIENT_CONNECT_WITH_DB
    | CapabilityFlags.CLIENT_PROTOCOL_41
    | CapabilityFlags.CLIENT_SECURE_CONNECTION
    | CapabilityFlags.CLIENT_SESSION_TRACK;
22 
/// Outcome of the most recent command: row counters, last insert id,
/// server status flags, error code and warning count.
struct ConnectionStatus
{
    ulong affected, matched, changed, lastInsertId;
    ushort flags, error, warnings;
}
33 
/// Connection parameters: server endpoint, credentials, default schema and
/// the client capability flags to request.
private struct ConnectionSettings
{
    /// Parses a connection string of `key=value` pairs separated by `;`,
    /// e.g. `host=localhost;user=root;pwd=secret;db=test;port=3306`.
    ///
    /// Recognised keys are `host`, `user`, `pwd`, `db` and `port`. Keys and
    /// values are stripped of surrounding whitespace; string values are
    /// slices of `connectionString`, not copies. Blank segments (for example
    /// the one left by a trailing `;`) are ignored. A value cannot contain `;`.
    ///
    /// Throws: MySQLException if the string holds no setting, has a segment
    /// without `=`, or names an unknown key; std.conv.ConvException if `port`
    /// is not a valid ushort.
    this(const(char)[] connectionString)
    {
        auto remaining = connectionString;
        bool sawSetting = false;

        while (!remaining.empty)
        {
            // Cut off the next ';'-terminated segment.
            const(char)[] segment;
            auto separator = remaining.indexOf(';');
            if (separator < 0)
            {
                segment = remaining;
                remaining = null;
            }
            else
            {
                segment = remaining[0 .. separator];
                remaining = remaining[separator + 1 .. $];
            }

            if (strip(segment).empty)
            {
                continue;
            }

            // Without '=' there is no key/value split: report it rather than
            // slicing with the -1 returned by indexOf.
            auto equals = segment.indexOf('=');
            if (equals < 0)
            {
                throw new MySQLException(format("Bad connection string: %s", connectionString));
            }

            auto name = strip(segment[0 .. equals]);
            auto value = strip(segment[equals + 1 .. $]);

            switch (name)
            {
                case "host":
                    host = value;
                    break;
                case "user":
                    user = value;
                    break;
                case "pwd":
                    pwd = value;
                    break;
                case "db":
                    db = value;
                    break;
                case "port":
                    port = to!ushort(value);
                    break;
                default:
                    throw new MySQLException(format("Bad connection string: %s", connectionString));
            }

            sawSetting = true;
        }

        if (!sawSetting)
        {
            throw new MySQLException(format("Bad connection string: %s", connectionString));
        }
    }

    CapabilityFlags caps = DefaultClientCaps;

    const(char)[] host;
    const(char)[] user;
    const(char)[] pwd;
    const(char)[] db;
    ushort port = 3306;
}
88 
/// Details announced by the server in its initial handshake packet.
private struct ServerInfo
{
    const(char)[] versionString;
    ubyte protocol, charSet;
    ushort status;
    uint connection, caps;
}
98 
/// Handle of a server-side prepared statement: the statement id returned by
/// COM_STMT_PREPARE and the number of `?` parameters it expects.
private struct PreparedStatement
{
    uint id, params;
}
104 
105 class Connection
106 {
/// Opens a connection described by a `key=value;...` connection string.
///
/// CLIENT_LONG_PASSWORD and CLIENT_PROTOCOL_41 are always requested in
/// addition to `caps`.
this(string connectionString, CapabilityFlags caps = DefaultClientCaps)
{
    auto parsed = ConnectionSettings(connectionString);
    parsed.caps = caps | CapabilityFlags.CLIENT_PROTOCOL_41 | CapabilityFlags.CLIENT_LONG_PASSWORD;
    settings_ = parsed;

    connect();
}
114 
/// Opens a connection from explicit parameters using DefaultClientCaps.
this(const(char)[] host, const(char)[] user, const(char)[] pwd, const(char)[] db, ushort port = 3306)
{
    // Forward to the full constructor, naming the default capability set.
    this(host, user, pwd, db, port,
         DefaultClientCaps);
}
119 
/// Opens a connection from explicit parameters.
///
/// CLIENT_LONG_PASSWORD and CLIENT_PROTOCOL_41 are always requested in
/// addition to `caps`. The string arguments are stored as given, not copied.
this(const(char)[] host, const(char)[] user, const(char)[] pwd, const(char)[] db, ushort port = 3306, CapabilityFlags caps = DefaultClientCaps)
{
    settings_.port = port;
    settings_.db = db;
    settings_.pwd = pwd;
    settings_.user = user;
    settings_.host = host;
    settings_.caps = caps | CapabilityFlags.CLIENT_PROTOCOL_41 | CapabilityFlags.CLIENT_LONG_PASSWORD;

    connect();
}
131 
/// Switches the default schema (COM_INIT_DB).
///
/// When CLIENT_SESSION_TRACK was negotiated, the schema change is expected to
/// come back in the OK packet's session-state data (see eatStatus).
/// Otherwise the new name is recorded locally in schema_.
void use(const(char)[] db)
{
    send(Commands.COM_INIT_DB, db);
    eatStatus(retrieve());

    immutable tracksSession = (caps_ & CapabilityFlags.CLIENT_SESSION_TRACK) != 0;
    if (tracksSession)
        return;

    schema_.length = db.length;
    schema_[] = db[];
}
143 
/// Sends COM_PING and consumes the server's status reply.
void ping()
{
    send(Commands.COM_PING, null, 0);
    auto reply = retrieve();
    eatStatus(reply);
}
149 
/// Sends COM_REFRESH and consumes the server's status reply.
///
/// NOTE(review): the MySQL protocol documentation describes a 1-byte
/// sub_command payload for COM_REFRESH, which is not sent here. Confirm the
/// target servers accept the bare command.
void refresh()
{
    send(Commands.COM_REFRESH, null, 0);
    auto reply = retrieve();
    eatStatus(reply);
}
155 
/// Sends COM_RESET_CONNECTION and consumes the server's status reply.
void reset()
{
    send(Commands.COM_RESET_CONNECTION, null, 0);
    auto reply = retrieve();
    eatStatus(reply);
}
161 
/// Sends COM_STATISTICS and returns the text the server replies with.
///
/// The reply is copied, so the returned string stays valid after later
/// commands on this connection.
const(char)[] statistics()
{
    send(Commands.COM_STATISTICS);

    auto answer = retrieve();
    // retrieve() reads every packet into the same connection-owned buffer,
    // so a plain slice of it would be overwritten by the next command.
    return answer.eat!(const(char)[])(answer.remaining).dup;
}
169 
/// Returns the default schema currently tracked for this connection.
///
/// The slice refers to the connection's internal schema_ buffer, which is
/// overwritten in place when the schema changes. Copy it if it must be kept.
const(char)[] schema() const
{
    const(char)[] current = schema_;
    return current;
}
174 
/// Prepares `sql` on the server (COM_STMT_PREPARE) and returns its handle.
///
/// The parameter and column definitions sent after the prepare-OK header are
/// read and discarded; only the statement id and parameter count are kept.
/// A reply that does not start with the OK marker is handed to eatStatus,
/// which throws for an ERR packet.
auto prepare(const(char)[] sql)
{
    send(Commands.COM_STMT_PREPARE, sql);

    auto reply = retrieve();
    if (reply.peek!ubyte != StatusPackets.OK_Packet)
        eatStatus(reply);

    reply.expect!ubyte(0);                  // status byte
    immutable statementId = reply.eat!uint;
    immutable columnCount = reply.eat!ushort;
    immutable paramCount = reply.eat!ushort;
    reply.expect!ubyte(0);                  // filler
    cast(void) reply.eat!ushort;            // warning count, not used

    // Each non-empty definition list is followed by an EOF packet.
    void skipDefinitions(size_t count)
    {
        if (count == 0)
            return;
        foreach (_; 0 .. count)
            skipColumnDef(retrieve(), Commands.COM_STMT_PREPARE);
        eatEOF(retrieve());
    }

    skipDefinitions(paramCount);
    skipDefinitions(columnCount);

    return PreparedStatement(statementId, paramCount);
}
211 
/// Runs `sql` through the text protocol (COM_QUERY) instead of preparing it.
///
/// Arguments are substituted by query(). A trailing callable argument, if
/// any, is forwarded as the row handler.
void executeNoPrepare(Args...)(const(char)[] sql, Args args)
{
    this.query(sql, args);
}
216 
/// Executes `sql` as a server-side prepared statement, preparing it on first use.
///
/// Prepared handles are cached per SQL text for this connection. `args` are
/// bound as statement parameters. A trailing callable argument is forwarded
/// as the row handler.
void execute(Args...)(const(char)[] sql, Args args)
{
    //scope(failure) close_();

    PreparedStatement stmt;
    if (auto cached = sql in clientPreparedCaches_)
    {
        stmt = *cached;
    }
    else
    {
        stmt = prepare(sql);
        // Key the cache with an immutable copy: `sql` may be a caller-owned
        // mutable buffer, and changing an associative-array key after it
        // has been inserted corrupts the lookup.
        clientPreparedCaches_[sql.idup] = stmt;
    }

    execute(stmt, args);
    // closePreparedStatement(stmt);
}
235 
/// Sets a session variable via `set session ?=?`.
///
/// The variable name is wrapped in MySQLFragment (presumably spliced in
/// unquoted — confirm), while `value` is bound as a regular parameter.
void set(T)(const(char)[] variable, T value)
{
    auto fragment = MySQLFragment(variable);
    query("set session ?=?", fragment, value);
}
240 
/// Reads a session variable with `show session variables like ?`.
///
/// Returns a copy of column 1 of the last matching row, or null when no row
/// matches.
const(char)[] get(const(char)[] variable)
{
    const(char)[] found;
    query("show session variables like ?", variable,
        (MySQLRow row) { found = row[1].peek!(const(char)[]).dup; });
    return found;
}
250 
/// Starts a transaction.
///
/// Throws: MySQLErrorException if a transaction is already active.
void startTransaction()
{
    if (!inTransaction)
    {
        query("start transaction");
        assert(inTransaction);
        return;
    }

    throw new MySQLErrorException("MySQL does not support nested transactions - commit or rollback before starting a new transaction");
}
262 
/// Commits the active transaction.
///
/// Throws: MySQLErrorException if no transaction is active.
void commit()
{
    if (inTransaction)
    {
        query("commit");
        assert(!inTransaction);
        return;
    }

    throw new MySQLErrorException("No active transaction");
}
274 
/// Rolls back the active transaction. Does nothing when not connected.
///
/// Throws: MySQLErrorException if connected but no transaction is active.
void rollback()
{
    if (!connected)
        return;

    // Connected at this point, so inTransaction reduces to the
    // SERVER_STATUS_IN_TRANS status flag.
    if (!inTransaction)
        throw new MySQLErrorException("No active transaction");

    query("rollback");
    assert(!inTransaction);
}
291 
/// True when connected and the last status packet had SERVER_STATUS_IN_TRANS set.
@property bool inTransaction() const
{
    if (!connected)
        return false;
    return (status_.flags & StatusFlags.SERVER_STATUS_IN_TRANS) != 0;
}
296 
/// Executes a prepared statement over the binary protocol (COM_STMT_EXECUTE).
///
/// `args` are bound to the statement's placeholders in order:
/// - a `null` literal binds NULL;
/// - a MySQLValue binds NULL when `isNull` is set;
/// - a non-string array contributes one placeholder per element.
/// If the last element of `args` is callable, it is not bound. Instead it is
/// passed to resultSet() as the row handler. Without a handler, any returned
/// rows are discarded.
///
/// Throws: MySQLErrorException when the number of bound values differs from
/// `stmt.params`. Errors reported by the server surface through eatStatus().
void execute(Args...)(PreparedStatement stmt, Args args)
{
    //scope(failure) close_();

    ensureConnected();

    seq_ = 0;
    auto packet = OutputPacket(&out_);
    packet.put!ubyte(Commands.COM_STMT_EXECUTE);
    packet.put!uint(stmt.id);
    packet.put!ubyte(Cursors.CURSOR_TYPE_READ_ONLY);
    packet.put!uint(1);     // iteration count

    static if (args.length == 0)
    {
        enum shouldDiscard = true;
    }
    else
    {
        enum shouldDiscard = !isCallable!(args[args.length - 1]);
    }

    enum argCount = shouldDiscard ? args.length : (args.length - 1);

    if (!argCount && stmt.params)
    {
        throw new MySQLErrorException(format("Wrong number of parameters for query. Got 0 but expected %d.", stmt.params));
    }

    static if (argCount)
    {
        // The NULL bitmap is staged in a fixed buffer and written to the
        // packet in NullsCapacity-bit chunks, so any number of placeholders
        // can be described.
        enum NullsCapacity = 128; // must be a multiple of 8
        ubyte[NullsCapacity >> 3] nulls;
        size_t bitsOut;     // bitmap bits already written to the packet
        size_t indexArg;    // placeholders accounted for so far

        foreach (i, arg; args[0..argCount])
        {
            // Address the bit relative to the staged chunk. The flush below
            // keeps `indexArg - bitsOut` under NullsCapacity at this point,
            // whereas the absolute placeholder index would run past the end
            // of `nulls` for a NULL at placeholder 129 or later.
            const auto staged = indexArg - bitsOut;
            const auto index = staged >> 3;
            const auto bit = staged & 7;

            static if (is(typeof(arg) == typeof(null)))
            {
                nulls[index] |= cast(ubyte)(1 << bit);
                ++indexArg;
            }
            else static if (is(Unqual!(typeof(arg)) == MySQLValue))
            {
                if (arg.isNull)
                {
                    nulls[index] |= cast(ubyte)(1 << bit);
                }
                ++indexArg;
            }
            else static if (isArray!(typeof(arg)) && !isSomeString!(typeof(arg)))
            {
                indexArg += arg.length;
            }
            else
            {
                ++indexArg;
            }

            auto finishing = (i == argCount - 1);
            auto remaining = indexArg - bitsOut;

            // Flush full chunks as soon as they exist; flush everything after
            // the last argument.
            if (finishing || (remaining >= NullsCapacity))
            {
                while (remaining)
                {
                    auto bits = min(remaining, NullsCapacity);

                    packet.put(nulls[0..(bits + 7) >> 3]);
                    bitsOut += bits;
                    nulls[] = 0;

                    remaining = (indexArg - bitsOut);
                    if (!remaining || (!finishing && (remaining < NullsCapacity)))
                    {
                        break;
                    }
                }
            }
        }
        packet.put!ubyte(1);    // new-params-bound flag: parameter types follow

        if (indexArg != stmt.params)
        {
            throw new MySQLErrorException(format("Wrong number of parameters for query. Got %d but expected %d.", indexArg, stmt.params));
        }

        // Parameter types first, then the parameter values.
        foreach (arg; args[0..argCount])
        {
            static if (is(typeof(arg) == enum))
            {
                putValueType(packet, cast(OriginalType!(Unqual!(typeof(arg))))arg);
            }
            else
            {
                putValueType(packet, arg);
            }
        }

        foreach (arg; args[0..argCount])
        {
            static if (is(typeof(arg) == enum))
            {
                putValue(packet, cast(OriginalType!(Unqual!(typeof(arg))))arg);
            }
            else
            {
                putValue(packet, arg);
            }
        }
    }

    packet.finalize(seq_);
    ++seq_;

    socket_.write(packet.get());

    auto answer = retrieve();
    if (isStatus(answer))
    {
        eatStatus(answer);
    }
    else
    {
        static if (!shouldDiscard)
        {
            resultSet(answer, stmt.id, Commands.COM_STMT_EXECUTE, args[args.length - 1]);
        }
        else
        {
            discardAll(answer, Commands.COM_STMT_EXECUTE);
        }
    }
}
435 
/// Deallocates a server-side prepared statement (COM_STMT_CLOSE).
///
/// No reply is read: COM_STMT_CLOSE has no server response.
/// NOTE(review): the id is sent in host byte order via send(T); this is only
/// correct on little-endian hosts.
void closePreparedStatement(PreparedStatement stmt)
{
    uint[1] payload;
    payload[0] = stmt.id;
    send(Commands.COM_STMT_CLOSE, payload);
}
441 
/// Last insert id reported by the most recent OK packet.
@property ulong lastInsertId() const
{
    ulong id = status_.lastInsertId;
    return id;
}
446 
/// Rows affected by the last command, as reported by its OK packet.
/// Reset to 0 by EOF and ERR packets.
@property ulong affected() const
{
    // Returned at full width. Narrowing through size_t would truncate
    // counts above uint.max on 32-bit targets.
    return status_.affected;
}
451 
/// Rows matched by the last command.
///
/// The value is parsed from the "matched: N changed: N" part of the OK
/// packet's info text. It is 0 when that text is absent.
@property ulong matched() const
{
    // No size_t cast: it would truncate on 32-bit targets.
    return status_.matched;
}
456 
/// Rows changed by the last command.
///
/// The value is parsed from the "matched: N changed: N" part of the OK
/// packet's info text. It is 0 when that text is absent.
@property ulong changed() const
{
    // No size_t cast: it would truncate on 32-bit targets.
    return status_.changed;
}
461 
/// Warning count from the last OK or EOF packet (0 after an ERR packet).
@property size_t warnings() const
{
    size_t count = status_.warnings;
    return count;
}
466 
/// Server error code from the last ERR packet (0 after OK or EOF).
@property size_t error() const
{
    size_t code = status_.error;
    return code;
}
471 
/// Info/message text of the last status packet, e.g. the server's error message.
///
/// The slice refers to the internal info_ buffer, which each status packet
/// overwrites.
@property const(char)[] status() const
{
    const(char)[] text = info_;
    return text;
}
476 
/// Whether the underlying socket reports an open connection.
@property bool connected() const
{
    const open = socket_.connected;
    return open;
}
481 
/// Releases cached prepared statements, then closes the socket.
void close()
{
    // Releasing the cached handles sends COM_STMT_CLOSE, so it must happen
    // before the socket goes away.
    this.clearClientPreparedCache();
    this.socket_.close();
}
487 
/// Prepares the connection for another user.
///
/// Steps: reconnect if needed, roll back any open transaction, then switch
/// back to the configured schema if it differs from the tracked one.
void reuse()
{
    ensureConnected();

    if (inTransaction)
        rollback();

    const wantedDb = settings_.db;
    if (wantedDb.length == 0 || wantedDb == schema_)
        return;

    use(wantedDb);
}
502 
/// Sets the trace flag. Its consumer is outside this section of the file.
@property void trace(bool enable)
{
    this.trace_ = enable;
}
507 
/// Current value of the trace flag.
@property bool trace()
{
    bool enabled = trace_;
    return enabled;
}
512 
513 package:
514 
/// Whether the connection is currently marked as in use.
@property bool busy()
{
    bool inUse = busy_;
    return inUse;
}
519 
/// Marks the connection as in use or idle.
/// Going idle releases all cached prepared statements.
@property void busy(bool value)
{
    busy_ = value;

    if (value)
        return;

    clearClientPreparedCache();
}
529 
/// Whether the connection is flagged as belonging to a pool.
@property bool pooled()
{
    bool fromPool = pooled_;
    return fromPool;
}
534 
/// Flags or unflags the connection as belonging to a pool.
@property void pooled(bool value)
{
    this.pooled_ = value;
}
539 
/// Time stamp last stored with the releaseTime setter.
@property DateTime releaseTime()
{
    DateTime stamp = releaseTime_;
    return stamp;
}
544 
/// Stores the time stamp returned by the releaseTime getter.
@property void releaseTime(DateTime value)
{
    this.releaseTime_ = value;
}
549 
550 private:
551 
/// Private alias for close(). It is referenced by the commented-out
/// `scope(failure)` guards in this class.
void close_()
{
    this.close();
}
556 
/// Sends `sql` over the text protocol (COM_QUERY).
///
/// Non-callable arguments are substituted into the SQL by prepareSQL before
/// sending. A trailing callable is passed to resultSetText as the row
/// handler. Otherwise any result set is discarded.
void query(Args...)(const(char)[] sql, Args args)
{
    //scope(failure) close_();

    static if (args.length == 0)
        enum shouldDiscard = true;
    else
        enum shouldDiscard = !isCallable!(args[$ - 1]);

    enum argCount = shouldDiscard ? args.length : args.length - 1;

    // static if introduces no scope, so `text` is visible below.
    static if (argCount == 0)
        auto text = sql;
    else
        auto text = prepareSQL(sql, args[0 .. argCount]);

    send(Commands.COM_QUERY, text);

    auto answer = retrieve();
    if (!isStatus(answer))
    {
        static if (shouldDiscard)
            discardAll(answer, Commands.COM_QUERY);
        else
            resultSetText(answer, Commands.COM_QUERY, args[$ - 1]);
        return;
    }

    eatStatus(answer);
}
600 
/// Opens the socket, performs the handshake/authentication exchange, and
/// forgets any cached prepared statements.
///
/// This is only reached for a fresh object or from ensureConnected() after
/// the socket closed, so cached handles belong to the previous session. They
/// are dropped without sending COM_STMT_CLOSE over the new connection.
/// If the handshake fails, the socket is closed again. Otherwise `connected`
/// would keep reporting an unauthenticated connection that later commands
/// would try to use.
void connect()
{
    socket_.connect(settings_.host, settings_.port);
    scope (failure) socket_.close();

    seq_ = 0;
    eatHandshake(retrieve());

    clientPreparedCaches_.clear();
}
609 
/// Sends `cmd` followed by the raw in-memory bytes of `data`.
void send(T)(Commands cmd, T[] data)
{
    const byteCount = data.length * T.sizeof;
    send(cmd, cast(ubyte*) data.ptr, byteCount);
}
614 
/// Starts a new command exchange and writes one command packet.
///
/// The sequence number is reset to 0. The header (which carries the command
/// byte and accounts for `length` payload bytes) is written first, then the
/// `length` bytes at `data` are written directly to the socket.
void send(Commands cmd, ubyte* data = null, size_t length = 0)
{
    ensureConnected();

    seq_ = 0;
    auto header = OutputPacket(&out_);
    header.put!ubyte(cmd);
    header.finalize(seq_, length);
    seq_++;

    socket_.write(header.get());

    if (length == 0)
        return;

    socket_.write(data[0 .. length]);
}
631 
/// Reconnects (including the handshake) if the socket is not connected.
void ensureConnected()
{
    if (socket_.connected)
        return;

    connect();
}
639 
/// Empties the prepared-statement cache and closes the cached statements on
/// the server when the connection is still open.
///
/// The cache is emptied before anything is sent. Sending goes through send(),
/// which may reconnect, and connect() clears this cache too. Emptying first
/// means a re-entrant clear cannot mutate the table while it is being
/// iterated.
/// When the socket is already closed, the handles are simply dropped: the
/// statements ended with that session, and closing them would force a
/// reconnect first.
void clearClientPreparedCache()
{
    if (clientPreparedCaches_.length == 0)
    {
        return;
    }

    auto stale = clientPreparedCaches_.values;
    clientPreparedCaches_.clear();

    if (!connected)
    {
        return;
    }

    foreach (p; stale)
    {
        closePreparedStatement(p);
    }
}
654 
/// True when the packet's first byte marks an ERR or OK packet.
/// It only peeks and does not consume. EOF packets are not matched here.
bool isStatus(InputPacket packet)
{
    const leading = packet.peek!ubyte;
    return leading == StatusPackets.ERR_Packet
        || leading == StatusPackets.OK_Packet;
}
668 
/// Hands the packet to eatStatus when it is an ERR or OK packet, which
/// throws for ERR. Any other packet is left untouched.
void check(InputPacket packet, bool smallError = false)
{
    if (isStatus(packet))
        eatStatus(packet, smallError);
}
683 
/// Reads one packet from the socket into the connection's receive buffer.
///
/// The 4-byte header holds a 3-byte little-endian payload length and a
/// sequence id, which must match the expected seq_. The returned packet reads
/// from in_, which every call overwrites. Payloads split across several
/// packets appear not to be reassembled here.
/// Throws: MySQLConnectionException on a sequence mismatch or short read.
InputPacket retrieve()
{
    //scope(failure) close_();

    ubyte[4] header;
    socket_.read(header);

    const payloadLength = header[0] | (header[1] << 8) | (header[2] << 16);
    const sequence = header[3];

    if (sequence != seq_)
        throw new MySQLConnectionException("Out of order packet received");
    seq_++;

    in_.length = payloadLength;
    socket_.read(in_);

    if (in_.length != payloadLength)
        throw new MySQLConnectionException("Wrong number of bytes read");

    return InputPacket(&in_);
}
711 
/// Handles the server's initial handshake packet and logs in.
///
/// Steps:
/// 1. Parse the handshake into server_ (protocol, version, connection id,
///    capabilities, charset, status) and collect the auth scramble.
/// 2. Set caps_ to settings_.caps & server_.caps.
/// 3. Send the handshake response: capabilities, user name, password token
///    and optional schema.
/// 4. Consume the server's reply through eatStatus.
///
/// The password token is SHA1(pwd) XOR SHA1(scramble ~ SHA1(SHA1(pwd))).
/// Throws:
/// - MySQLProtocolException if the server lacks CLIENT_PROTOCOL_41;
/// - MySQLConnectionException if the scramble does not fit the 256-byte buffer;
/// - whatever eatStatus throws for an ERR reply, during the handshake or the login.
712     void eatHandshake(InputPacket packet)
713     {
714         //scope(failure) close_();
715 
716         check(packet, true); // an ERR packet instead of a handshake is raised here
717 
718         server_.protocol = packet.eat!ubyte; // protocol version byte
719         server_.versionString = packet.eat!(const(char)[])(packet.countUntil(0, true)).dup; // copied: the receive buffer is reused by retrieve()
720         packet.skip(1); // NUL terminator of the version string
721 
722         server_.connection = packet.eat!uint; // connection id
723 
724         const auto authLengthStart = 8;
725         size_t authLength = authLengthStart;
726 
727         ubyte[256] auth;
728         auth[0..authLength] = packet.eat!(ubyte[])(authLength); // first 8 bytes of the scramble
729 
730         packet.expect!ubyte(0); // filler
731 
732         server_.caps = packet.eat!ushort; // lower 16 bits of the server capabilities
733 
734         if (!packet.empty) // optional trailing section of the handshake
735         {
736             server_.charSet = packet.eat!ubyte;
737             server_.status = packet.eat!ushort;
738             server_.caps |= packet.eat!ushort << 16; // upper 16 bits of the server capabilities
739             server_.caps |= CapabilityFlags.CLIENT_LONG_PASSWORD;
740 
741             if ((server_.caps & CapabilityFlags.CLIENT_PROTOCOL_41) == 0)
742             {
743                 throw new MySQLProtocolException("Server doesn't support protocol v4.1");
744             }
745 
746             if (server_.caps & CapabilityFlags.CLIENT_SECURE_CONNECTION)
747             {
748                 packet.skip(1); // auth-data length byte; the rest of the scramble is read up to its NUL instead
749             }
750             else
751             {
752                 packet.expect!ubyte(0);
753             }
754 
755             packet.skip(10); // reserved bytes
756 
757             authLength += packet.countUntil(0, true); // second part of the scramble, up to its NUL
758             if (authLength > auth.length)
759             {
760                 throw new MySQLConnectionException("Bad packet format");
761             }
762 
763             auth[authLengthStart..authLength] = packet.eat!(ubyte[])(authLength - authLengthStart);
764 
765             packet.expect!ubyte(0);
766         }
767 
768         caps_ = cast(CapabilityFlags)(settings_.caps & server_.caps); // negotiated = requested & offered
769 
770         ubyte[20] token;
771         {
772             import std.digest.sha : sha1Of;
773 
774             auto pass = sha1Of(cast(const(ubyte)[])settings_.pwd); // SHA1(pwd)
775             token = sha1Of(auth[0..authLength], sha1Of(pass)); // SHA1(scramble ~ SHA1(SHA1(pwd)))
776 
777             foreach (i; 0..20)
778             {
779                 token[i] = token[i] ^ pass[i];
780             }
781         }
782 
783         auto reply = OutputPacket(&out_);
784 
785         reply.reserve(64 + settings_.user.length + settings_.pwd.length + settings_.db.length);
786 
787         reply.put!uint(caps_);
788         reply.put!uint(1); // max packet size; NOTE(review): sent as 1 — confirm this is intended
789         reply.put!ubyte(45); // character set / collation id 45 (utf8mb4_general_ci — confirm)
790         reply.fill(0, 23); // reserved filler
791 
792         reply.put(settings_.user); // NUL-terminated user name
793         reply.put!ubyte(0);
794 
795         if (settings_.pwd.length)
796         {
797             if (caps_ & CapabilityFlags.CLIENT_SECURE_CONNECTION)
798             {
799                 reply.put!ubyte(token.length); // length-prefixed auth response
800                 reply.put(token);
801             }
802             else
803             {
804                 reply.put(token); // NUL-terminated auth response
805                 reply.put!ubyte(0);
806             }
807         }
808         else
809         {
810             reply.put!ubyte(0); // empty auth response
811         }
812 
813         if ((settings_.db.length || schema_.length) && (caps_ & CapabilityFlags.CLIENT_CONNECT_WITH_DB))
814         {
815             if (schema_.length) // a schema already tracked for this connection wins over settings_.db
816             {
817                 reply.put(schema_);
818             }
819             else
820             {
821                 reply.put(settings_.db);
822 
823                 schema_.length = settings_.db.length;
824                 schema_[] = settings_.db[];
825             }
826         }
827 
828         reply.put!ubyte(0); // schema terminator; written even when no schema was sent
829 
830         reply.finalize(seq_); // continues the handshake's sequence numbering
831         ++seq_;
832 
833         socket_.write(reply.get());
834 
835         eatStatus(retrieve()); // server's verdict on the login
836     }
837 
/// Consumes a status packet (OK, EOF or ERR) and updates status_ and info_.
/// With session tracking, it may also update schema_.
///
/// - OK: records affected rows, last insert id, status flags, warnings and
///   the info text. A "matched: N changed: N" fragment in the info text
///   fills matched/changed.
/// - EOF: records warnings and status flags.
/// - ERR: keeps only the in-transaction flag, records the error code and
///   message, then throws.
/// Params: smallError = when true, the 6 bytes after the error code are not
///   skipped (presumably the SQL-state marker and state). check() forwards
///   true for the handshake.
/// Throws:
/// - a dedicated exception for duplicate entry, data too long, deadlock,
///   missing table and lock wait timeout;
/// - MySQLErrorException for other ERR codes;
/// - MySQLProtocolException for any other leading byte.
838     void eatStatus(InputPacket packet, bool smallError = false)
839     {
840         auto id = packet.eat!ubyte;
841 
842         switch (id)
843         {
844             case StatusPackets.OK_Packet:
845                 status_.matched = 0;
846                 status_.changed = 0;
847                 status_.affected = packet.eatLenEnc();
848                 status_.lastInsertId = packet.eatLenEnc();
849                 status_.flags = packet.eat!ushort;
850                 if (caps_ & CapabilityFlags.CLIENT_PROTOCOL_41)
851                 {
852                     status_.warnings = packet.eat!ushort;
853                 }
854                 status_.error = 0;
855                 info([]);
856 
857                 if (caps_ & CapabilityFlags.CLIENT_SESSION_TRACK) // info is length-encoded and may be followed by session-state changes
858                 {
859                     if (!packet.empty)
860                     {
861                         info(packet.eat!(const(char)[])(cast(size_t)packet.eatLenEnc()));
862 
863                         if (status_.flags & StatusFlags.SERVER_SESSION_STATE_CHANGED)
864                         {
865                             packet.skipLenEnc(); // total length of the session-state block
866                             while (!packet.empty())
867                             {
868                                 final switch (packet.eat!ubyte()) with (SessionStateType) // one entry: type byte + length-encoded data
869                                 {
870                                     case SESSION_TRACK_SCHEMA:
871                                         packet.skipLenEnc(); // entry length
872                                         schema_.length = cast(size_t)packet.eatLenEnc();
873                                         schema_[] = packet.eat!(const(char)[])(schema_.length);
874                                         break;
875                                     case SESSION_TRACK_SYSTEM_VARIABLES:
876                                     case SESSION_TRACK_GTIDS:
877                                     case SESSION_TRACK_STATE_CHANGE:
878                                     case SESSION_TRACK_TRANSACTION_STATE:
879                                     case SESSION_TRACK_TRANSACTION_CHARACTERISTICS:
880                                         packet.skip(cast(size_t)packet.eatLenEnc()); // not tracked by this client
881                                         break;
882                                 }
883                             }
884                         }
885                     }
886                 }
887                 else
888                 {
889                     info(packet.eat!(const(char)[])(packet.remaining)); // rest of the packet is the info text
890                 }
891 
892                 import std.regex : matchFirst, regex;
893                 static matcher = regex(`\smatched:\s*(\d+)\s+changed:\s*(\d+)`, `i`);
894                 auto matches = matchFirst(info_, matcher); // e.g. "Rows matched: 1  Changed: 1  Warnings: 0"
895 
896                 if (!matches.empty)
897                 {
898                     status_.matched = matches[1].to!ulong;
899                     status_.changed = matches[2].to!ulong;
900                 }
901 
902                 break;
903             case StatusPackets.EOF_Packet:
904                 status_.affected = 0;
905                 status_.changed = 0;
906                 status_.matched = 0;
907                 status_.error = 0;
908                 status_.warnings = packet.eat!ushort; // EOF carries warnings first, then status flags
909                 status_.flags = packet.eat!ushort;
910                 info([]);
911 
912                 break;
913             case StatusPackets.ERR_Packet:
914                 status_.affected = 0;
915                 status_.changed = 0;
916                 status_.matched = 0;
917                 //status_.flags = 0;//[shove]
918                 status_.flags &= StatusFlags.SERVER_STATUS_IN_TRANS; // keep only the in-transaction bit so inTransaction survives an error
919                 status_.warnings = 0;
920                 status_.error = packet.eat!ushort;
921                 if (!smallError)
922                 {
923                     packet.skip(6);
924                 }
925                 info(packet.eat!(const(char)[])(packet.remaining)); // error message text
926 
927                 switch(status_.error)
928                 {
929                     case ErrorCodes.ER_DUP_ENTRY_WITH_KEY_NAME:
930                     case ErrorCodes.ER_DUP_ENTRY:
931                         throw new MySQLDuplicateEntryException(info_.idup);
932                     case ErrorCodes.ER_DATA_TOO_LONG_FOR_COL:
933                         throw new MySQLDataTooLongException(info_.idup);
934                     case ErrorCodes.ER_DEADLOCK_FOUND:
935                         throw new MySQLDeadlockFoundException(info_.idup);
936                     case ErrorCodes.ER_TABLE_DOESNT_EXIST:
937                         throw new MySQLTableDoesntExistException(info_.idup);
938                     case ErrorCodes.ER_LOCK_WAIT_TIMEOUT:
939                         throw new MySQLLockWaitTimeoutException(info_.idup);
940                     default:
941                         version(development)
942                         {
943                             // On dev show the query together with the error message
944                             throw new MySQLErrorException(format("[err:%s] %s - %s", status_.error, info_, sql_.data));
945                         }
946                         else
947                         {
948                             throw new MySQLErrorException(format("[err:%s] %s", status_.error, info_));
949                         }
950                 }
951             default:
952                 throw new MySQLProtocolException("Unexpected packet format");
953         }
954     }
955 
/// Copies `value` into the info_ buffer. status() later returns this buffer.
void info(const(char)[] value)
{
    info_.length = value.length;
    info_[] = value[];
}
961 
/// Reads past one column-definition packet without keeping any of it.
/// For COM_FIELD_LIST replies, the trailing default-value field is skipped too.
void skipColumnDef(InputPacket packet, Commands cmd)
{
    // catalog, schema, table, original_table, name, original_name
    foreach (_; 0 .. 6)
        packet.skip(cast(size_t)packet.eatLenEnc());

    packet.skipLenEnc();        // next_length
    packet.skip(10);            // charset(2) + length(4) + type(1) + flags(2) + decimals(1)
    packet.expect!ushort(0);    // filler

    if (cmd != Commands.COM_FIELD_LIST)
        return;

    packet.skip(cast(size_t)packet.eatLenEnc()); // default values
}
979 
/// Parses one column-definition packet into `def`.
///
/// The column name is appended to the shared columns_ buffer, and def.name
/// is a slice of that buffer. Length, type, flags and decimals are decoded;
/// other fields are skipped. For COM_FIELD_LIST, the default-value field is
/// skipped as well.
void columnDef(InputPacket packet, Commands cmd, ref MySQLColumn def)
{
    // catalog, schema, table, original_table
    foreach (_; 0 .. 4)
        packet.skip(cast(size_t)packet.eatLenEnc());

    const nameLength = cast(size_t)packet.eatLenEnc();
    columns_ ~= packet.eat!(const(char)[])(nameLength);
    def.name = columns_[$ - nameLength .. $];

    packet.skip(cast(size_t)packet.eatLenEnc());    // original_name
    packet.skipLenEnc();                            // next_length
    packet.skip(2);                                 // charset
    def.length = packet.eat!uint;
    def.type = cast(ColumnTypes)packet.eat!ubyte;
    def.flags = packet.eat!ushort;
    def.decimals = packet.eat!ubyte;
    packet.expect!ushort(0);                        // filler

    if (cmd == Commands.COM_FIELD_LIST)
        packet.skip(cast(size_t)packet.eatLenEnc()); // default values
}
1004 
/// Resizes `defs` to `count`, then reads one column-definition packet per entry.
void columnDefs(size_t count, Commands cmd, ref MySQLColumn[] defs)
{
    defs.length = count;
    foreach (ref def; defs)
        columnDef(retrieve(), cmd, def);
}
1013 
/// Invokes a row handler that takes only the row.
/// Returns the handler's bool result, or true for a void handler.
bool callHandler(RowHandler)(RowHandler handler, size_t, MySQLHeader, MySQLRow row) if ((ParameterTypeTuple!(RowHandler).length == 1) && is(ParameterTypeTuple!(RowHandler)[0] == MySQLRow))
{
    static if (!is(ReturnType!(RowHandler) == void))
    {
        return handler(row); // return type must be bool
    }
    else
    {
        handler(row);
        return true;
    }
}
1026 
/// Invokes a row handler that takes (index, row). The index is converted to
/// the handler's declared numeric type. Returns the handler's bool result,
/// or true for a void handler.
bool callHandler(RowHandler)(RowHandler handler, size_t i, MySQLHeader, MySQLRow row) if ((ParameterTypeTuple!(RowHandler).length == 2) && isNumeric!(ParameterTypeTuple!(RowHandler)[0]) && is(ParameterTypeTuple!(RowHandler)[1] == MySQLRow))
{
    auto index = cast(ParameterTypeTuple!(RowHandler)[0])i;

    static if (!is(ReturnType!(RowHandler) == void))
    {
        return handler(index, row); // return type must be bool
    }
    else
    {
        handler(index, row);
        return true;
    }
}
1039 
1040     bool callHandler(RowHandler)(RowHandler handler, size_t, MySQLHeader header, MySQLRow row) if ((ParameterTypeTuple!(RowHandler).length == 2) && is(ParameterTypeTuple!(RowHandler)[0] == MySQLHeader) && is(ParameterTypeTuple!(RowHandler)[1] == MySQLRow))
1041     {
1042         static if (is(ReturnType!(RowHandler) == void))
1043         {
1044             handler(header, row);
1045             return true;
1046         }
1047         else
1048         {
1049             return handler(header, row); // return type must be bool
1050         }
1051     }
1052 
1053     bool callHandler(RowHandler)(RowHandler handler, size_t i, MySQLHeader header, MySQLRow row) if ((ParameterTypeTuple!(RowHandler).length == 3) && isNumeric!(ParameterTypeTuple!(RowHandler)[0]) && is(ParameterTypeTuple!(RowHandler)[1] == MySQLHeader) && is(ParameterTypeTuple!(RowHandler)[2] == MySQLRow))
1054     {
1055         static if (is(ReturnType!(RowHandler) == void))
1056         {
1057             handler(i, header, row);
1058             return true;
1059         }
1060         else
1061         {
1062             return handler(i, header, row); // return type must be bool
1063         }
1064     }
1065 
1066     void resultSetRow(InputPacket packet, MySQLHeader header, ref MySQLRow row)
1067     {
1068         assert(row.columns.length == header.length);
1069 
1070         packet.expect!ubyte(0);
1071         auto nulls = packet.eat!(ubyte[])((header.length + 2 + 7) >> 3);
1072 
1073         foreach (i, ref column; header)
1074         {
1075             const auto index = (i + 2) >> 3; // bit offset of 2
1076             const auto bit = (i + 2) & 7;
1077 
1078             if ((nulls[index] & (1 << bit)) == 0)
1079             {
1080                 eatValue(packet, column, row.get_(i));
1081             }
1082             else
1083             {
1084                 auto signed = (column.flags & FieldFlags.UNSIGNED_FLAG) == 0;
1085                 row.get_(i) = MySQLValue(column.name, ColumnTypes.MYSQL_TYPE_NULL, signed, null, 0);
1086             }
1087         }
1088         assert(packet.empty);
1089     }
1090 
1091     void resultSet(RowHandler)(InputPacket packet, uint stmt, Commands cmd, RowHandler handler)
1092     {
1093         columns_.length = 0;
1094 
1095         auto columns = cast(size_t)packet.eatLenEnc();
1096         columnDefs(columns, cmd, header_);
1097         row_.header_(header_);
1098 
1099         auto status = retrieve();
1100         if (status.peek!ubyte == StatusPackets.ERR_Packet)
1101         {
1102             eatStatus(status);
1103         }
1104 
1105         size_t index;
1106         auto statusFlags = eatEOF(status);
1107         if (statusFlags & StatusFlags.SERVER_STATUS_CURSOR_EXISTS)
1108         {
1109             uint[2] data = [ stmt, 4096 ]; // todo: make setting - rows per fetch
1110             while (statusFlags & (StatusFlags.SERVER_STATUS_CURSOR_EXISTS | StatusFlags.SERVER_MORE_RESULTS_EXISTS))
1111             {
1112                 send(Commands.COM_STMT_FETCH, data);
1113 
1114                 auto answer = retrieve();
1115                 if (answer.peek!ubyte == StatusPackets.ERR_Packet)
1116                 {
1117                     eatStatus(answer);
1118                 }
1119 
1120                 auto row = answer.empty ? retrieve() : answer;
1121                 while (true)
1122                 {
1123                     if (row.peek!ubyte == StatusPackets.EOF_Packet)
1124                     {
1125                         statusFlags = eatEOF(row);
1126                         break;
1127                     }
1128 
1129                     resultSetRow(row, header_, row_);
1130                     if (!callHandler(handler, index++, header_, row_))
1131                     {
1132                         discardUntilEOF(retrieve());
1133                         statusFlags = 0;
1134                         break;
1135                     }
1136                     row = retrieve();
1137                 }
1138             }
1139         }
1140         else
1141         {
1142             while (true)
1143             {
1144                 auto row = retrieve();
1145                 if (row.peek!ubyte == StatusPackets.EOF_Packet)
1146                 {
1147                     eatEOF(row);
1148                     break;
1149                 }
1150 
1151                 resultSetRow(row, header_, row_);
1152                 if (!callHandler(handler, index++, header_, row_))
1153                 {
1154                     discardUntilEOF(retrieve());
1155                     break;
1156                 }
1157             }
1158         }
1159     }
1160 
1161     void resultSetRowText(InputPacket packet, MySQLHeader header, ref MySQLRow row)
1162     {
1163         assert(row.columns.length == header.length);
1164 
1165         foreach(i, ref column; header)
1166         {
1167             if (packet.peek!ubyte != 0xfb)
1168             {
1169                 eatValueText(packet, column, row.get_(i));
1170             }
1171             else
1172             {
1173                 packet.skip(1);
1174                 auto signed = (column.flags & FieldFlags.UNSIGNED_FLAG) == 0;
1175                 row.get_(i) = MySQLValue(column.name, ColumnTypes.MYSQL_TYPE_NULL, signed, null, 0);
1176             }
1177         }
1178         assert(packet.empty);
1179     }
1180 
1181     void resultSetText(RowHandler)(InputPacket packet, Commands cmd, RowHandler handler)
1182     {
1183         columns_.length = 0;
1184 
1185         auto columns = cast(size_t)packet.eatLenEnc();
1186         columnDefs(columns, cmd, header_);
1187         row_.header_(header_);
1188 
1189         eatEOF(retrieve());
1190 
1191         size_t index;
1192         while (true) {
1193             auto row = retrieve();
1194             if (row.peek!ubyte == StatusPackets.EOF_Packet)
1195             {
1196                 eatEOF(row);
1197                 break;
1198             } else if (row.peek!ubyte == StatusPackets.ERR_Packet)
1199             {
1200                 eatStatus(row);
1201                 break;
1202             }
1203 
1204             resultSetRowText(row, header_, row_);
1205             if (!callHandler(handler, index++, header_, row_))
1206             {
1207                 discardUntilEOF(retrieve());
1208                 break;
1209             }
1210         }
1211     }
1212 
1213     void discardAll(InputPacket packet, Commands cmd)
1214     {
1215         auto columns = cast(size_t)packet.eatLenEnc();
1216         columnDefs(columns, cmd, header_);
1217 
1218         auto statusFlags = eatEOF(retrieve());
1219         if ((statusFlags & StatusFlags.SERVER_STATUS_CURSOR_EXISTS) == 0)
1220         {
1221             while (true)
1222             {
1223                 auto row = retrieve();
1224                 if (row.peek!ubyte == StatusPackets.EOF_Packet)
1225                 {
1226                     eatEOF(row);
1227                     break;
1228                 }
1229             }
1230         }
1231     }
1232 
1233     void discardUntilEOF(InputPacket packet)
1234     {
1235         while (true)
1236         {
1237             if (packet.peek!ubyte == StatusPackets.EOF_Packet)
1238             {
1239                 eatEOF(packet);
1240                 break;
1241             }
1242             packet = retrieve();
1243         }
1244     }
1245 
1246     auto eatEOF(InputPacket packet)
1247     {
1248         auto id = packet.eat!ubyte;
1249         if (id != StatusPackets.EOF_Packet)
1250         {
1251             throw new MySQLProtocolException("Unexpected packet format");
1252         }
1253 
1254         status_.error = 0;
1255         status_.warnings = packet.eat!ushort();
1256         status_.flags = packet.eat!ushort();
1257         info([]);
1258 
1259         return status_.flags;
1260     }
1261 
1262     auto estimateArgs(Args...)(ref size_t estimated, Args args)
1263     {
1264         size_t argCount;
1265 
1266         foreach(i, arg; args)
1267         {
1268             static if (is(typeof(arg) == typeof(null)))
1269             {
1270                 ++argCount;
1271                 estimated += 4;
1272             }
1273             else static if (is(Unqual!(typeof(arg)) == MySQLValue))
1274             {
1275                 ++argCount;
1276                 final switch(arg.type) with (ColumnTypes)
1277                 {
1278                     case MYSQL_TYPE_NULL:
1279                         estimated += 4;
1280                         break;
1281                     case MYSQL_TYPE_TINY:
1282                         estimated += 4;
1283                         break;
1284                     case MYSQL_TYPE_YEAR:
1285                     case MYSQL_TYPE_SHORT:
1286                         estimated += 6;
1287                         break;
1288                     case MYSQL_TYPE_INT24:
1289                     case MYSQL_TYPE_LONG:
1290                         estimated += 6;
1291                         break;
1292                     case MYSQL_TYPE_LONGLONG:
1293                         estimated += 8;
1294                         break;
1295                     case MYSQL_TYPE_FLOAT:
1296                         estimated += 8;
1297                         break;
1298                     case MYSQL_TYPE_DOUBLE:
1299                         estimated += 8;
1300                         break;
1301                     case MYSQL_TYPE_SET:
1302                     case MYSQL_TYPE_ENUM:
1303                     case MYSQL_TYPE_VARCHAR:
1304                     case MYSQL_TYPE_VAR_STRING:
1305                     case MYSQL_TYPE_STRING:
1306                     case MYSQL_TYPE_JSON:
1307                     case MYSQL_TYPE_NEWDECIMAL:
1308                     case MYSQL_TYPE_DECIMAL:
1309                     case MYSQL_TYPE_TINY_BLOB:
1310                     case MYSQL_TYPE_MEDIUM_BLOB:
1311                     case MYSQL_TYPE_LONG_BLOB:
1312                     case MYSQL_TYPE_BLOB:
1313                     case MYSQL_TYPE_BIT:
1314                     case MYSQL_TYPE_GEOMETRY:
1315                         estimated += 2 + arg.peek!(const(char)[]).length;
1316                         break;
1317                     case MYSQL_TYPE_TIME:
1318                     case MYSQL_TYPE_TIME2:
1319                         estimated += 18;
1320                         break;
1321                     case MYSQL_TYPE_DATE:
1322                     case MYSQL_TYPE_NEWDATE:
1323                     case MYSQL_TYPE_DATETIME:
1324                     case MYSQL_TYPE_DATETIME2:
1325                     case MYSQL_TYPE_TIMESTAMP:
1326                     case MYSQL_TYPE_TIMESTAMP2:
1327                         estimated += 20;
1328                         break;
1329                 }
1330             }
1331             else static if (isArray!(typeof(arg)) && !isSomeString!(typeof(arg)))
1332             {
1333                 argCount += arg.length;
1334                 estimated += arg.length * 6;
1335             }
1336             else static if (isSomeString!(typeof(arg)) || is(Unqual!(typeof(arg)) == MySQLRawString) || is(Unqual!(typeof(arg)) == MySQLFragment) || is(Unqual!(typeof(arg)) == MySQLBinary))
1337             {
1338                 ++argCount;
1339                 estimated += 2 + arg.length;
1340             }
1341             else
1342             {
1343                 ++argCount;
1344                 estimated += 6;
1345             }
1346         }
1347 
1348         return argCount;
1349     }
1350 
1351     auto prepareSQL(Args...)(const(char)[] sql, Args args)
1352     {
1353         auto estimated = sql.length;
1354         auto argCount = estimateArgs(estimated, args);
1355 
1356         sql_.clear;
1357         sql_.reserve(max(8192, estimated));
1358 
1359         alias AppendFunc = bool function(ref Appender!(char[]), ref const(char)[] sql, ref size_t, const(void)*) @safe pure nothrow;
1360         AppendFunc[Args.length] funcs;
1361         const(void)*[Args.length] addrs;
1362 
1363         foreach (i, Arg; Args)
1364         {
1365             static if (is(Arg == enum))
1366             {
1367                 funcs[i] = () @trusted { return cast(AppendFunc)&appendNextValue!(OriginalType!Arg); }();
1368                 addrs[i] = (ref x) @trusted { return cast(const void*)&x; }(cast(OriginalType!(Unqual!Arg))args[i]);
1369             }
1370             else
1371             {
1372                 funcs[i] = () @trusted { return cast(AppendFunc)&appendNextValue!(Arg); }();
1373                 addrs[i] = (ref x) @trusted { return cast(const void*)&x; }(args[i]);
1374             }
1375         }
1376 
1377         size_t indexArg;
1378         foreach (i; 0..Args.length)
1379         {
1380             if (!funcs[i](sql_, sql, indexArg, addrs[i]))
1381             {
1382                 throw new MySQLErrorException(format("Wrong number of parameters for query. Got %d but expected %d.", argCount, indexArg));
1383             }
1384         }
1385 
1386         finishCopy(sql_, sql, argCount, indexArg);
1387 
1388         return sql_.data;
1389     }
1390 
1391     void finishCopy(ref Appender!(char[]) app, ref const(char)[] sql, size_t argCount, size_t indexArg)
1392     {
1393         if (copyUpToNext(sql_, sql))
1394         {
1395             ++indexArg;
1396 
1397             while (copyUpToNext(sql_, sql))
1398             {
1399                 ++indexArg;
1400             }
1401 
1402             throw new MySQLErrorException(format("Wrong number of parameters for query. Got %d but expected %d.", argCount, indexArg));
1403         }
1404     }
1405 
    // Transport socket (type from database.mysql.socket); presumably the server connection — confirm.
    Socket socket_;
    // Column definitions of the most recently read result set (filled by columnDefs).
    MySQLHeader header_;
    // Row object reused for every row handed to a result-set handler.
    MySQLRow row_;
    // Backing storage for column names: columnDef appends each name here and
    // stores a slice of it as the column's `name`. resultSet/resultSetText truncate it first.
    char[] columns_;
    // NOTE(review): presumably the info string managed via info(); confirm.
    char[] info_;
    // NOTE(review): presumably the current default schema/database; confirm.
    char[] schema_;
    // NOTE(review): presumably receive/send packet buffers; confirm.
    ubyte[] in_;
    ubyte[] out_;
    // NOTE(review): presumably the protocol packet sequence id; confirm.
    ubyte seq_;
    // Scratch buffer that prepareSQL renders the final query text into (reused per call).
    Appender!(char[]) sql_;

    // Capability flags for this connection (presumably negotiated at handshake — confirm).
    CapabilityFlags caps_;
    // Status of the last OK/EOF packet; eatEOF updates error, warnings and flags.
    ConnectionStatus status_;
    // Settings parsed from the connection string.
    ConnectionSettings settings_;
    // NOTE(review): presumably server handshake information; confirm.
    ServerInfo server_;

    // For tracing queries
    bool trace_;

    // Prepared statements keyed by SQL text (presumably a client-side cache — confirm against lookup code).
    PreparedStatement[const(char)[]] clientPreparedCaches_;

    // NOTE(review): presumably connection-pool bookkeeping (in use, pool-owned,
    // time last returned to the pool); confirm against the pool code.
    bool busy_;
    bool pooled_;
    DateTime releaseTime_;
1430 }
1431 
/**
 * Copies text from `sql` into `app` up to the next placeholder `?` that is not
 * inside a '...', "..." or `...` quoted span. The placeholder itself is not copied.
 *
 * Inside a quoted span a backslash escapes the following character, so an
 * escaped quote does not close the span.
 *
 * On return `sql` is advanced past what was consumed (including the `?`).
 * Returns: true if a placeholder was found; false if the rest of `sql` was copied.
 */
private auto copyUpToNext(ref Appender!(char[]) app, ref const(char)[] sql)
{
    dchar openQuote = '\0'; // quote character of the span we are inside, or '\0'
    size_t pos;

    while (pos < sql.length)
    {
        immutable dchar c = decode!(UseReplacementDchar.no)(sql, pos);

        if (openQuote == '\0')
        {
            if (c == '?')
            {
                // '?' is a single UTF-8 code unit, so it sits at pos - 1.
                app.put(sql[0 .. pos - 1]);
                sql = sql[pos .. $];
                return true;
            }
            if (c == '\'' || c == '"' || c == '`')
                openQuote = c;
        }
        else if (c == openQuote)
        {
            openQuote = '\0';
        }
        else if (c == '\\' && pos < sql.length)
        {
            // Skip the escaped character.
            decode!(UseReplacementDchar.no)(sql, pos);
        }
    }

    app.put(sql[0 .. pos]);
    sql = sql[pos .. $];
    return false;
}
1478 
/**
 * Type-erased step of prepareSQL: interprets `arg` as a pointer to a `T` and
 * substitutes it into the next placeholder(s) of `sql`, appending to `app`.
 *
 * A non-string array fills one placeholder per element; any other value fills
 * exactly one. `indexArg` is incremented for every placeholder filled.
 *
 * Returns: false as soon as `sql` has no placeholder left for a value; otherwise true.
 */
private bool appendNextValue(T)(ref Appender!(char[]) app, ref const(char)[] sql, ref size_t indexArg, const(void)* arg)
{
    auto value = cast(T*)arg;

    static if (isArray!T && !isSomeString!(OriginalType!T))
    {
        foreach (ref element; *value)
        {
            if (!copyUpToNext(app, sql))
                return false;
            appendValue(app, element);
            ++indexArg;
        }
        return true;
    }
    else
    {
        if (!copyUpToNext(app, sql))
            return false;
        appendValue(app, *value);
        ++indexArg;
        return true;
    }
}