1 module mysql.connection;
2 
3 import std.array;
4 import std.functional;
5 import std.string;
6 import std.traits;
7 
8 public import mysql.exception;
9 import mysql.packet;
10 public import mysql.protocol;
11 public import mysql.type;
12 
13 
/// Capability flags a connection requests unless the caller supplies its own
/// set: the 4.1 protocol, secure connection, long password, long flag and
/// connect-with-db.
immutable CapabilityFlags DefaultClientCaps =
    CapabilityFlags.CLIENT_PROTOCOL_41 |
    CapabilityFlags.CLIENT_SECURE_CONNECTION |
    CapabilityFlags.CLIENT_LONG_PASSWORD |
    CapabilityFlags.CLIENT_LONG_FLAG |
    CapabilityFlags.CLIENT_CONNECT_WITH_DB;
16 
17 
/// Parameters used to open a connection.
struct ConnectionSettings {
    /// Client capability flags to request during the handshake.
    CapabilityFlags caps = DefaultClientCaps;

    /// Server host name, user name, password and default database.
    const(char)[] host, user, pwd, db;

    /// Server port (3306 unless overridden).
    ushort port = 3306;
}
27 
28 
/// Connection state updated from the server's OK / EOF / ERR packets.
struct ConnectionStatus {
    /// Negotiated capabilities. Initialised explicitly to zero: an enum's
    /// `.init` is its first member, which need not be 0.
    CapabilityFlags caps = cast(CapabilityFlags)0;

    // Integral fields start at their `.init` value of 0.
    ulong affected, insertID;
    ushort flags, error, warnings;
}
38 
39 
/// Values read from the server's initial handshake packet.
struct ServerInfo {
    const(char)[] versionString;
    ubyte protocol, charSet;
    ushort status;
    uint connection, caps;
}
48 
49 
/// Handle to a server-side prepared statement, produced by `Connection.prepare`.
struct PreparedStatement {
package:
    // todo: investigate if it's really necessary to close statements explicitly
    uint id, params;    // server-assigned statement id; number of placeholders
}
55 
56 
/// A MySQL client connection running over a caller-supplied transport.
///
/// NOTE(review): `SocketType` must provide `connect(host, port)`, `connected`,
/// `read(buffer)`, `write(bytes)` and `close()`. These are the members used
/// below; confirm the exact signatures against the socket implementation.
struct Connection(SocketType) {
    /// Parses `connectionString` and connects.
    ///
    /// The string is a list of `name=value` pairs separated by `;`. The
    /// recognised names are host, user, pwd, db and port. Empty segments, such
    /// as one left by a trailing `;`, are ignored.
    /// Throws: MySQLException if the string is empty, a segment lacks `=`, or
    /// a name is not recognised.
    void connect(string connectionString) {
        connectionSettings(connectionString);
        connect();
    }

    /// Stores the given settings and connects. CLIENT_LONG_PASSWORD and
    /// CLIENT_PROTOCOL_41 are always OR-ed into `caps`.
    void connect(const(char)[] host, ushort port, const(char)[] user, const(char)[] pwd, const(char)[] db, CapabilityFlags caps = DefaultClientCaps) {
        settings_.host = host;
        settings_.user = user;
        settings_.pwd = pwd;
        settings_.db = db;
        settings_.port = port;
        settings_.caps = caps | CapabilityFlags.CLIENT_LONG_PASSWORD | CapabilityFlags.CLIENT_PROTOCOL_41;

        connect();
    }

    /// Changes the default database with COM_INIT_DB.
    /// Throws: MySQLErrorException if the server replies with an ERR packet.
    void use(const(char)[] db) {
        send(Commands.COM_INIT_DB, db);
        status(retrieve());
    }

    /// Sends COM_PING and processes the status reply.
    void ping() {
        send(Commands.COM_PING);
        status(retrieve());
    }

    /// Sends COM_REFRESH and processes the status reply.
    void refresh() {
        send(Commands.COM_REFRESH);
        status(retrieve());
    }

    /// Sends COM_RESET_CONNECTION and processes the status reply.
    void reset() {
        send(Commands.COM_RESET_CONNECTION);
        status(retrieve());
    }

    /// Returns the text the server sends in response to COM_STATISTICS.
    /// NOTE(review): the slice presumably aliases the receive buffer and is
    /// overwritten by the next read. Copy it if it must outlive that.
    /// Throws: MySQLErrorException if the server replies with an ERR packet.
    const(char)[] statistics() {
        send(Commands.COM_STATISTICS);

        auto answer = retrieve();
        check(answer);
        return answer.eat!(const(char)[])(answer.remaining);
    }

    /// Prepares `sql` on the server with COM_STMT_PREPARE.
    ///
    /// The parameter and column definitions that follow the reply are read
    /// and discarded.
    /// Returns: a handle holding the statement id and the parameter count.
    /// Throws: MySQLErrorException if the server rejects the statement.
    auto prepare(const(char)[] sql) {
        send(Commands.COM_STMT_PREPARE, sql);

        auto answer = retrieve();
        // The prepare reply starts with 0x00, the same byte as an OK packet.
        // Only an ERR packet may go through status(); check() ensures that.
        check(answer);

        answer.expect!ubyte(0);

        auto id = answer.eat!uint;
        auto columns = answer.eat!ushort;
        auto params = answer.eat!ushort;
        answer.expect!ubyte(0);

        answer.eat!ushort(); // warning count: consumed, not exposed

        if (params) {
            MySQLColumn def;
            foreach (i; 0..params)
                columnDef(retrieve(), Commands.COM_STMT_PREPARE, def);

            skipEOF(retrieve());
        }

        if (columns) {
            MySQLColumn def;
            foreach (i; 0..columns)
                columnDef(retrieve(), Commands.COM_STMT_PREPARE, def);

            skipEOF(retrieve());
        }

        return PreparedStatement(id, params);
    }

    /// Prepares `stmt`, executes it with `args`, then closes the statement.
    ///
    /// If the last argument is callable, it is used as the row handler (see
    /// the `PreparedStatement` overload).
    void execute(Args...)(const(char)[] stmt, Args args) {
        auto id = prepare(stmt);
        // Release the statement even when execution throws. If the failure
        // already dropped the connection (retrieve() disconnects on failure),
        // skip the close: it would only trigger a reconnect, and server-side
        // statements belong to the old connection anyway.
        scope(exit) {
            if (socket_.connected)
                close(id);
        }
        execute(id, args);
    }

    /// Executes a prepared statement with COM_STMT_EXECUTE.
    ///
    /// If the last argument is callable, it is the row handler and receives
    /// each result row. Otherwise any result set is read and discarded.
    /// Arguments of type `typeof(null)` are sent as SQL NULL.
    /// Throws: MySQLErrorException if the number of parameters does not match
    /// the statement, or if the server replies with an ERR packet.
    void execute(Args...)(PreparedStatement stmt, Args args) {
        ensureConnected();

        seq_ = 0;
        auto packet = OutputPacket(&out_);
        packet.put!ubyte(Commands.COM_STMT_EXECUTE);
        packet.put!uint(stmt.id);
        packet.put!ubyte(Cursors.CURSOR_TYPE_READ_ONLY);
        packet.put!uint(1); // iteration count

        // A trailing callable argument is the row handler, not a parameter.
        static if (args.length == 0) {
            enum shouldDiscard = true;
        } else {
            enum shouldDiscard = !isCallable!(args[args.length - 1]);
        }

        enum argCount = shouldDiscard ? args.length : (args.length - 1);

        if (argCount != stmt.params)
            throw new MySQLErrorException("Wrong number of parameters for query");

        static if (argCount) {
            // NULL bitmap: one bit per bound parameter. It is sized from
            // argCount, not args.length, so the row handler adds no extra byte.
            ubyte[(argCount + 7) >> 3] nulls;
            foreach (i, arg; args[0..argCount]) {
                const auto index = i >> 3;
                const auto bit = i & 7;

                static if (is(typeof(arg) == typeof(null))) {
                    nulls[index] = nulls[index] | (1 << bit);
                }
            }

            packet.put(nulls[]);
            packet.put!ubyte(1); // the parameter types follow

            foreach (arg; args[0..argCount])
                putValueType(packet, arg);

            foreach (arg; args[0..argCount]) {
                static if (!is(typeof(arg) == typeof(null))) {
                    putValue(packet, arg);
                }
            }
        }

        packet.finalize(seq_);
        ++seq_;

        socket_.write(packet.get());

        auto answer = retrieve();
        if (isStatus(answer)) {
            status(answer);
        } else {
            static if (!shouldDiscard) {
                resultSet(answer, stmt.id, Commands.COM_STMT_EXECUTE, args[args.length - 1]);
            } else {
                discardAll(answer, Commands.COM_STMT_EXECUTE);
            }
        }
    }

    /// Sends COM_STMT_CLOSE for `stmt`. No reply is read.
    void close(PreparedStatement stmt) {
        uint[1] data = [ stmt.id ];
        send(Commands.COM_STMT_CLOSE, data);
    }

    /// Last insert id reported by the most recent OK packet.
    ulong insertID() {
        // Returned as the full ulong: casting to size_t truncated on 32-bit targets.
        return status_.insertID;
    }

    /// Affected-row count reported by the most recent OK packet.
    ulong affected() {
        return status_.affected;
    }

    /// Warning count from the most recent OK/EOF packet.
    size_t warnings() {
        return status_.warnings;
    }

    /// Error code from the most recent ERR packet (0 after an OK packet).
    size_t error() {
        return status_.error;
    }

    /// Info / error message text from the most recent status packet.
    const(char)[] status() {
        return info_;
    }

    /// Closes the underlying socket.
    void disconnect() {
        socket_.close();
    }

    ~this() {
        disconnect();
    }

private:
    /// Opens the socket and runs the handshake using `settings_`.
    void connect() {
        socket_.connect(settings_.host, settings_.port);

        seq_ = 0;
        handshake(retrieve());
    }

    /// Sends a command whose payload is the raw bytes of `data`.
    void send(T)(Commands cmd, T[] data) {
        send(cmd, cast(ubyte*)data.ptr, data.length * T.sizeof);
    }

    /// Sends a command packet, reconnecting first if the socket is closed.
    /// The command byte is written as the header, followed by `length` bytes
    /// of payload.
    void send(Commands cmd, ubyte* data = null, size_t length = 0) {
        if(!socket_.connected) {
            connect();
        } else {
            seq_ = 0;
        }

        auto header = OutputPacket(&out_);
        header.put!ubyte(cmd);
        header.finalize(seq_, length);
        ++seq_;

        socket_.write(header.get());
        if (length)
            socket_.write(data[0..length]);
    }

    /// Connects if the socket is not currently connected.
    void ensureConnected() {
        if(!socket_.connected)
            connect();
    }

    /// True if `packet` is an OK or ERR packet.
    bool isStatus(InputPacket packet) {
        auto id = packet.peek!ubyte;
        switch (id) {
            case StatusPackets.ERR_Packet:
            case StatusPackets.OK_Packet:
                return true;
            default:
                return false;
        }
    }

    /// Throws the server error if `packet` is an ERR packet. Otherwise does
    /// nothing.
    ///
    /// OK-prefixed packets are deliberately not parsed here: some replies,
    /// such as COM_STMT_PREPARE_OK, also start with 0x00 but are not OK packets.
    void check(InputPacket packet) {
        if (packet.peek!ubyte == StatusPackets.ERR_Packet)
            status(packet); // always throws for ERR
    }

    /// Reads one packet (4-byte header plus payload) into `in_`.
    /// The socket is closed if anything here throws.
    /// Throws: MySQLConnectionException on a sequence mismatch or short read.
    InputPacket retrieve() {
        scope(failure) disconnect();

        ubyte[4] header;
        socket_.read(header);

        // 3-byte little-endian payload length, then the sequence id.
        auto len = header[0] | (header[1] << 8) | (header[2] << 16);
        auto seq = header[3];

        if (seq != seq_)
            throw new MySQLConnectionException("Out of order packet received");

        ++seq_;

        in_.length = len;
        socket_.read(in_);

        if (in_.length != len)
            throw new MySQLConnectionException("Wrong number of bytes read");

        return InputPacket(&in_);
    }

    /// Parses the server greeting, sends the login reply (user, password
    /// token, optional database) and processes the server's answer.
    ///
    /// The password token is SHA1(pwd) XOR SHA1(scramble ~ SHA1(SHA1(pwd))).
    /// Throws: MySQLProtocolException if the server lacks protocol 4.1, and
    /// MySQLConnectionException on a malformed greeting.
    void handshake(InputPacket packet) {
        scope(failure) disconnect();

        server_.protocol = packet.eat!ubyte;
        server_.versionString = packet.eat!(const(char)[])(packet.countUntil(0, true));
        packet.skip(1);

        server_.connection = packet.eat!uint;

        // The first 8 bytes of the scramble.
        const auto authLengthStart = 8;
        size_t authLength = authLengthStart;

        ubyte[256] auth;
        auth[0..authLength] = packet.eat!(ubyte[])(authLength);

        packet.expect!ubyte(0);

        server_.caps = packet.eat!ushort;

        if (!packet.empty) {
            server_.charSet = packet.eat!ubyte;
            server_.status = packet.eat!ushort;
            server_.caps |= packet.eat!ushort << 16;
            server_.caps |= CapabilityFlags.CLIENT_LONG_PASSWORD;

            if ((server_.caps & CapabilityFlags.CLIENT_PROTOCOL_41) == 0)
                throw new MySQLProtocolException("Server doesn't support protocol v4.1");

            if (server_.caps & CapabilityFlags.CLIENT_SECURE_CONNECTION) {
                packet.skip(1);
            } else {
                packet.expect!ubyte(0);
            }

            packet.skip(10);

            // The rest of the scramble, terminated by 0.
            authLength += packet.countUntil(0, true);
            if (authLength > auth.length)
                throw new MySQLConnectionException("Bad packet format");

            auth[authLengthStart..authLength] = packet.eat!(ubyte[])(authLength - authLengthStart);

            packet.expect!ubyte(0);
        }

        ubyte[20] token;
        {
            import std.digest.sha;

            auto pass = sha1Of(cast(const(ubyte)[])settings_.pwd);
            token = sha1Of(pass);

            SHA1 sha1;
            sha1.start();
            sha1.put(auth[0..authLength]);
            sha1.put(token);
            token = sha1.finish();

            foreach (i; 0..20)
                token[i] = token[i] ^ pass[i];
        }

        status_.caps = cast(CapabilityFlags)(settings_.caps & server_.caps);

        auto reply = OutputPacket(&out_);
        reply.reserve(64 + settings_.user.length + settings_.pwd.length + settings_.db.length);

        reply.put!uint(status_.caps);
        reply.put!uint(1);
        reply.put!ubyte(33);
        reply.fill(0, 23);

        reply.put(settings_.user);
        reply.put!ubyte(0);

        if (settings_.pwd.length) {
            if (status_.caps & CapabilityFlags.CLIENT_SECURE_CONNECTION) {
                reply.put!ubyte(token.length);
                reply.put(token);
            } else {
                reply.put(token);
                reply.put!ubyte(0);
            }
        } else {
            reply.put!ubyte(0);
        }

        if (settings_.db.length && (status_.caps & CapabilityFlags.CLIENT_CONNECT_WITH_DB)) {
            reply.put(settings_.db);
            reply.put!ubyte(0);
        }

        reply.finalize(seq_);
        ++seq_;

        socket_.write(reply.get());

        status(retrieve());
    }

    /// Consumes an OK, EOF or ERR packet and updates `status_` / `info_`.
    /// Throws: MySQLErrorException for ERR, MySQLProtocolException for any
    /// other packet type.
    void status(InputPacket packet) {
        auto id = packet.eat!ubyte;

        switch (id) {
        case StatusPackets.OK_Packet:
            status_.error = 0;
            status_.affected = packet.eatLenEnc();
            status_.insertID = packet.eatLenEnc();
            status_.flags = packet.eat!ushort;
            status_.warnings = packet.eat!ushort;

            // NOTE(review): the skip(1) calls below do not obviously match the
            // documented OK-packet layout. They are reached only when
            // CLIENT_SESSION_TRACK is negotiated, which DefaultClientCaps does
            // not request. Verify before enabling that flag.
            if (status_.caps & CapabilityFlags.CLIENT_SESSION_TRACK) {
                info(packet.eat!(const(char)[])(cast(size_t)packet.eatLenEnc()));
                packet.skip(1);

                if (status_.flags & StatusFlags.SERVER_SESSION_STATE_CHANGED) {
                    packet.skip(cast(size_t)packet.eatLenEnc());
                    packet.skip(1);
                }
            } else if (!packet.empty) {
                info(packet.eat!(const(char)[])(cast(size_t)packet.eatLenEnc()));
            }
            break;
        case StatusPackets.EOF_Packet:
            status_.warnings = packet.eat!ushort;
            status_.flags = packet.eat!ushort;
            info([]);
            break;
        case StatusPackets.ERR_Packet:
            status_.flags = 0;
            status_.error = packet.eat!ushort;
            packet.skip(6); // SQL state marker + 5-char SQL state
            info(packet.eat!(const(char)[])(packet.remaining));

            // Copy the message: info_ is a mutable buffer that is reused by
            // the next status packet, so it must not be cast to string.
            throw new MySQLErrorException(info_.idup);
        default:
            throw new MySQLProtocolException("Unexpected packet format");
        }
    }

    /// Copies `value` into the reusable `info_` buffer.
    void info(const(char)[] value) {
        info_.length = value.length;
        info_[0..$] = value;
    }

    /// Parses one column-definition packet into `def` (name, length, type,
    /// flags, decimals). The other fields are skipped.
    void columnDef(InputPacket packet, Commands cmd, ref MySQLColumn def) {
        packet.skip(cast(size_t)packet.eatLenEnc()); // catalog
        packet.skip(cast(size_t)packet.eatLenEnc()); // schema
        packet.skip(cast(size_t)packet.eatLenEnc()); // table
        packet.skip(cast(size_t)packet.eatLenEnc()); // org_table
        def.name = packet.eat!(const(char)[])(cast(size_t)packet.eatLenEnc()).idup; // todo: fix allocation
        packet.skip(cast(size_t)packet.eatLenEnc()); // org_name
        packet.eatLenEnc(); // length of the fixed-size fields that follow
        packet.skip(2); // character set
        def.length = packet.eat!uint;
        def.type = cast(ColumnTypes)packet.eat!ubyte;
        def.flags = packet.eat!ushort;
        def.decimals = packet.eat!ubyte;

        packet.expect!ushort(0);

        if (cmd == Commands.COM_FIELD_LIST) {
            packet.skip(cast(size_t)packet.eatLenEnc()); // default values
        }
    }

    /// Reads `count` column definitions into the reusable `header_` buffer
    /// and returns it.
    auto columnDefs(size_t count, Commands cmd) {
        header_.length = count;
        foreach (i; 0..count)
            columnDef(retrieve(), cmd, header_[i]);
        return header_;
    }

    /// Decodes one binary-protocol result row into `row`.
    void resultSetRow(InputPacket packet, Commands cmd, MySQLHeader header, MySQLRow row) {
        assert(row.length == header.length);

        packet.expect!ubyte(0);
        // NULL bitmap of a binary row starts at bit offset 2.
        auto nulls = packet.eat!(ubyte[])((header.length + 2 + 7) >> 3);
        foreach (i, column; header) {
            const auto index = (i + 2) >> 3; // bit offset of 2
            const auto bit = (i + 2) & 7;

            if ((nulls[index] & (1 << bit)) == 0) {
                row[i] = eatValue(packet, column);
            } else {
                row[i].nullify();
            }
        }
        assert(packet.empty);
    }

    /// Invokes `handler(row)`. Returns false when a bool-returning handler
    /// asks to stop.
    bool callHandler(RowHandler)(RowHandler handler, size_t i, MySQLHeader header, MySQLRow row) if ((ParameterTypeTuple!(RowHandler).length == 1) && is(ParameterTypeTuple!(RowHandler)[0] == MySQLRow)) {
        static if (is(ReturnType!(RowHandler) == void)) {
            handler(row);
            return true;
        } else {
            return handler(row); // return type must be bool
        }
    }

    /// Invokes `handler(index, row)`.
    bool callHandler(RowHandler)(RowHandler handler, size_t i, MySQLHeader header, MySQLRow row) if ((ParameterTypeTuple!(RowHandler).length == 2) && isNumeric!(ParameterTypeTuple!(RowHandler)[0]) && is(ParameterTypeTuple!(RowHandler)[1] == MySQLRow)) {
        static if (is(ReturnType!(RowHandler) == void)) {
            handler(cast(ParameterTypeTuple!(RowHandler)[0])i, row);
            return true;
        } else {
            return handler(cast(ParameterTypeTuple!(RowHandler)[0])i, row); // return type must be bool
        }
    }

    /// Invokes `handler(header, row)`.
    bool callHandler(RowHandler)(RowHandler handler, size_t i, MySQLHeader header, MySQLRow row) if ((ParameterTypeTuple!(RowHandler).length == 2) && is(ParameterTypeTuple!(RowHandler)[0] == MySQLHeader) && is(ParameterTypeTuple!(RowHandler)[1] == MySQLRow)) {
        static if (is(ReturnType!(RowHandler) == void)) {
            handler(header, row);
            return true;
        } else {
            return handler(header, row); // return type must be bool
        }
    }

    /// Invokes `handler(index, header, row)`. The index is cast to the
    /// handler's numeric parameter type, as in the (index, row) overload.
    bool callHandler(RowHandler)(RowHandler handler, size_t i, MySQLHeader header, MySQLRow row) if ((ParameterTypeTuple!(RowHandler).length == 3) && isNumeric!(ParameterTypeTuple!(RowHandler)[0]) && is(ParameterTypeTuple!(RowHandler)[1] == MySQLHeader) && is(ParameterTypeTuple!(RowHandler)[2] == MySQLRow)) {
        static if (is(ReturnType!(RowHandler) == void)) {
            handler(cast(ParameterTypeTuple!(RowHandler)[0])i, header, row);
            return true;
        } else {
            return handler(cast(ParameterTypeTuple!(RowHandler)[0])i, header, row); // return type must be bool
        }
    }

    /// Reads a result set and feeds each row to `handler`.
    ///
    /// When the server opened a cursor, rows are pulled with COM_STMT_FETCH.
    /// If the handler returns false, the remaining rows are discarded.
    void resultSet(RowHandler)(InputPacket packet, uint stmt, Commands cmd, RowHandler handler) {
        auto columns = cast(size_t)packet.eatLenEnc();
        auto header = columnDefs(columns, cmd);
        row_.length = columns;

        size_t index = 0;
        auto statusFlags = skipEOF(retrieve());
        if (statusFlags & StatusFlags.SERVER_STATUS_CURSOR_EXISTS) {
            uint[2] data = [ stmt, 4096 ]; // todo: make setting - rows per fetch
            while (statusFlags & (StatusFlags.SERVER_STATUS_CURSOR_EXISTS | StatusFlags.SERVER_MORE_RESULTS_EXISTS)) {
                send(Commands.COM_STMT_FETCH, data);

                auto answer = retrieve();
                if (answer.peek!ubyte == StatusPackets.ERR_Packet)
                    check(answer);

                auto row = answer.empty ? retrieve() : answer;
                while (true) {
                    if (row.peek!ubyte == StatusPackets.EOF_Packet) {
                        statusFlags = skipEOF(row);
                        break;
                    }

                    resultSetRow(row, Commands.COM_STMT_FETCH, header, row_);
                    if (!callHandler(handler, index++, header, row_)) {
                        discardUntilEOF(retrieve());
                        statusFlags = 0;
                        break;
                    }
                    row = retrieve();
                }
            }
        } else {
            auto row = retrieve();
            while (true) {
                if (row.peek!ubyte == StatusPackets.EOF_Packet) {
                    status(row);
                    break;
                }

                resultSetRow(row, cmd, header, row_);
                if (!callHandler(handler, index++, header, row_)) {
                    discardUntilEOF(retrieve());
                    break;
                }

                row = retrieve();
            }
        }
    }

    /// Reads and drops a whole result set. When the server reports an open
    /// cursor, no rows are fetched.
    void discardAll(InputPacket packet, Commands cmd) {
        auto columns = cast(size_t)packet.eatLenEnc();
        columnDefs(columns, cmd);

        auto statusFlags = skipEOF(retrieve());
        if ((statusFlags & StatusFlags.SERVER_STATUS_CURSOR_EXISTS) == 0) {
            while (true) {
                auto row = retrieve();
                if (row.peek!ubyte == StatusPackets.EOF_Packet) {
                    status(row);
                    break;
                }
            }
        }
    }

    /// Drops packets, starting with `packet`, until an EOF packet, then
    /// processes that EOF packet.
    void discardUntilEOF(InputPacket packet) {
        while (packet.peek!ubyte != StatusPackets.EOF_Packet)
            packet = retrieve();
        status(packet);
    }

    /// Consumes an EOF packet and returns its status flags.
    /// Throws: MySQLProtocolException if `packet` is not an EOF packet.
    auto skipEOF(InputPacket packet) {
        auto id = packet.eat!ubyte;
        if (id != StatusPackets.EOF_Packet)
            throw new MySQLProtocolException("Unexpected packet format");

        packet.skip(2); // warnings
        return packet.eat!ushort();
    }

    /// Fills `settings_` from a `name=value;...` connection string.
    ///
    /// Values are stored as slices of `connectionString`. Empty segments are
    /// ignored.
    /// Throws: MySQLException if no setting is present, a segment has no
    /// `=`, or a name is unknown. A non-numeric port throws a ConvException.
    void connectionSettings(const(char)[] connectionString) {
        import std.conv : to;

        size_t parsed = 0;
        foreach (part; connectionString.split(";")) {
            // Tolerate empty segments, e.g. the one left by a trailing ';'.
            if (strip(part).empty)
                continue;

            auto indexValue = part.indexOf("=");
            if (indexValue < 0)
                throw new MySQLException("Bad connection string: " ~ cast(string)connectionString);

            auto name = strip(part[0..indexValue]);
            auto value = strip(part[indexValue+1..$]);

            switch (name) {
            case "host":
                settings_.host = value;
                break;
            case "user":
                settings_.user = value;
                break;
            case "pwd":
                settings_.pwd = value;
                break;
            case "db":
                settings_.db = value;
                break;
            case "port":
                settings_.port = to!ushort(value);
                break;
            default:
                throw new MySQLException("Bad connection string: " ~ cast(string)connectionString);
            }

            ++parsed;
        }

        if (parsed == 0)
            throw new MySQLException("Bad connection string: " ~ cast(string)connectionString);
    }

    SocketType socket_;
    MySQLHeader header_;
    MySQLRow row_;
    char[] info_;      // last status message; reused between packets
    ubyte[] in_;       // receive buffer for the current packet
    ubyte[] out_;      // send buffer
    ubyte seq_ = 0;    // expected/next packet sequence id

    ConnectionStatus status_;
    ConnectionSettings settings_;
    ServerInfo server_;
}