module mysql.connection;


import std.algorithm;
import std.array;
import std.conv : to;
import std.string;
import std.traits;
import std.uni : sicmp;
import std.utf : decode, UseReplacementDchar;

import mysql.appender;
public import mysql.exception;
import mysql.packet;
import mysql.protocol;
import mysql.ssl;
public import mysql.type;


/// Capability flags the client advertises by default during the handshake.
immutable CapabilityFlags DefaultClientCaps = CapabilityFlags.CLIENT_LONG_PASSWORD | CapabilityFlags.CLIENT_LONG_FLAG |
	CapabilityFlags.CLIENT_CONNECT_WITH_DB | CapabilityFlags.CLIENT_PROTOCOL_41 | CapabilityFlags.CLIENT_SECURE_CONNECTION | CapabilityFlags.CLIENT_SESSION_TRACK;


/// Status information extracted from the last OK/EOF/ERR packet received.
struct ConnectionStatus {
	ulong affected;   /// affected-rows count from the last OK packet
	ulong matched;    /// "matched:" count parsed from the OK packet info text, if present
	ulong changed;    /// "changed:" count parsed from the OK packet info text, if present
	ulong insertID;   /// last-insert-id from the last OK packet
	ushort flags;     /// server status flags (StatusFlags)
	ushort error;     /// error code from the last ERR packet, 0 otherwise
	ushort warnings;  /// warning count
}


/// Connection parameters, usually parsed from a `name=value;name=value` connection string.
///
/// Recognised keys: host, user, pwd, db, port, ssl, ssl_rootcert, ssl_hostname,
/// ssl_ciphers, ssl_version, ssl_validate. Any other key raises MySQLException.
struct ConnectionSettings {
	/// Builds settings by parsing `connectionString` (see `parse`).
	this(const(char)[] connectionString) {
		parse(connectionString);
	}

	/// Parses a `name=value;...` connection string into this object.
	///
	/// Names and values are whitespace-stripped. String values are stored as slices of
	/// `connectionString` (they are not copied). Empty segments, such as a trailing `;`,
	/// are ignored. A value may contain `=` but not `;`.
	///
	/// Throws: MySQLException if the string contains no settings at all, a segment has no
	/// `=`, a key is unknown, or an `ssl`/`ssl_version`/`ssl_validate` value is not
	/// recognised. `port` is converted with `to!ushort`, which throws on invalid input.
	void parse(const(char)[] connectionString) {
		bool parsedAny;

		foreach (segment; connectionString.splitter(';')) {
			// Tolerate empty segments, e.g. "host=x;" or "host=x;;user=y".
			if (segment.strip.empty)
				continue;

			auto indexValue = segment.indexOf('=');
			if (indexValue < 0)
				throw new MySQLException(format("Bad connection string: %s", connectionString));

			auto name = strip(segment[0..indexValue]);
			auto value = strip(segment[indexValue+1..$]);

			switch (name) {
			case "host":
				host = value;
				break;
			case "user":
				user = value;
				break;
			case "pwd":
				pwd = value;
				break;
			case "db":
				db = value;
				break;
			case "port":
				port = to!ushort(value);
				break;
			case "ssl":
				switch (value) {
				case "0":
				case "no":
				case "false":
					break;
				case "require":
				case "required":
					// Enforced SSL also requests the SSL capability.
					ssl.enforce = true;
					goto case "yes";
				case "1":
				case "yes":
				case "true":
					caps |= CapabilityFlags.CLIENT_SSL;
					break;
				default:
					throw new MySQLException(format("Bad value for 'ssl' on connection string: %s", value));
				}
				break;
			case "ssl_rootcert":
				ssl.rootCertFile = value;
				break;
			case "ssl_hostname":
				ssl.hostName = value;
				break;
			case "ssl_ciphers":
				ssl.ciphers = value;
				break;
			case "ssl_version":
				switch (value) with (SSLConfig.Version) {
				case "any":
					ssl.sslVersion = any;
					break;
				case "ssl3":
					ssl.sslVersion = ssl3;
					break;
				case "tls1":
					ssl.sslVersion = tls1;
					break;
				case "tls1_1":
					ssl.sslVersion = tls1_1;
					break;
				case "tls1_2":
					ssl.sslVersion = tls1_2;
					break;
				case "dtls1":
					ssl.sslVersion = dtls1;
					break;
				default:
					throw new MySQLException(format("Bad value for 'ssl_version' on connection string: %s", value));
				}
				break;
			case "ssl_validate":
				switch (value) with (SSLConfig.Validate) {
				case "basic":
					ssl.validate = basic;
					break;
				case "trust":
					ssl.validate = trust;
					break;
				case "identity":
					ssl.validate = identity;
					break;
				default:
					throw new MySQLException(format("Bad value for 'ssl_validate' on connection string: %s", value));
				}
				break;
			default:
				throw new MySQLException(format("Bad connection string: %s", connectionString));
			}

			parsedAny = true;
		}

		// An empty (or whitespace/separator-only) connection string is rejected.
		if (!parsedAny)
			throw new MySQLException(format("Bad connection string: %s", connectionString));
	}

	CapabilityFlags caps = DefaultClientCaps; /// client capabilities; `ssl=yes` adds CLIENT_SSL

	const(char)[] host;
	const(char)[] user;
	const(char)[] pwd;
	const(char)[] db;
	ushort port = 3306;

	SSLConfig ssl;
}


/// Values read from the server's initial handshake packet.
private struct ServerInfo {
	const(char)[] versionString;
	ubyte protocol;
	ubyte charSet;
	ushort status;
	uint connection;
	uint caps;
}


/// Builds a comma-separated list of `x` SQL placeholders, e.g. `(?,?,?)`.
///
/// Params:
///   x      = number of placeholders
///   parens = wrap the list in parentheses (default true)
/// Returns: the placeholder list, or `null` when `x` is 0.
@property string placeholders(size_t x, bool parens = true) {
	if (!x)
		return null;

	auto app = appender!string;
	// x question marks and (x - 1) commas, plus two parentheses when requested.
	app.reserve(x + x - 1 + (parens ? 2 : 0));

	if (parens)
		app.put('(');
	foreach (i; 0..x - 1)
		app.put("?,");
	app.put('?');
	if (parens)
		app.put(')');

	return app.data;
}


/// Overload of `placeholders` taking anything with a `.length` (e.g. an array of values).
@property string placeholders(T)(T x, bool parens = true) if (is(typeof(() { auto y = x.length; }))) {
	return x.length.placeholders(parens);
}


/// Handle to a server-side prepared statement.
struct PreparedStatement {
package:
	uint id;     // statement id assigned by the server
	uint params; // number of parameters the statement expects
}


/// Compile-time options for `Connection`.
enum ConnectionOptions {
	TextProtocol = 1 << 0, // Execute method uses the MySQL text protocol under the hood - it's less safe but can increase performance in some situations
	TextProtocolCheckNoArgs = 1 << 1, // Check for orphan placeholders even if arguments are passed
	Default = 0
}


/// A MySQL client connection over `SocketType`.
///
/// `SocketType` must provide `connect(host, port)`, `connected`, `close()`, `read(buf)`,
/// `write(buf)` and `startSSL(host, sslConfig)`; these are the calls made below.
struct Connection(SocketType, ConnectionOptions Options = ConnectionOptions.Default) {
	/// Parses `connectionString` into the connection settings and connects.
	void connect(string connectionString) {
		settings_ = ConnectionSettings(connectionString);
		connect();
	}

	/// Stores `settings` and connects.
	void connect(ConnectionSettings settings) {
		settings_ = settings;
		connect();
	}

	/// Connects using explicit parameters. CLIENT_LONG_PASSWORD and CLIENT_PROTOCOL_41 are
	/// always added to `caps`.
	/// NOTE(review): the string arguments are stored as slices, not copied, and are read
	/// again whenever the connection is re-established; keep them alive and unchanged.
	void connect(const(char)[] host, ushort port, const(char)[] user, const(char)[] pwd, const(char)[] db, CapabilityFlags caps = DefaultClientCaps) {
		settings_.host = host;
		settings_.user = user;
		settings_.pwd = pwd;
		settings_.db = db;
		settings_.port = port;
		settings_.caps = caps | CapabilityFlags.CLIENT_LONG_PASSWORD | CapabilityFlags.CLIENT_PROTOCOL_41;

		connect();
	}

	/// Switches the default schema (COM_INIT_DB).
	/// Without session tracking the cached schema name is updated here; with session
	/// tracking it is updated from the server's OK packet in `eatStatus`.
	void use(const(char)[] db) {
		send(Commands.COM_INIT_DB, db);
		eatStatus(retrieve());

		if ((caps_ & CapabilityFlags.CLIENT_SESSION_TRACK) == 0) {
			schema_.length = db.length;
			schema_[] = db[];
		}
	}

	/// Sends COM_PING and consumes the status reply.
	void ping() {
		send(Commands.COM_PING);
		eatStatus(retrieve());
	}

	/// Sends COM_REFRESH and consumes the status reply.
	/// NOTE(review): no sub-command byte is sent with the command; confirm the server accepts this.
	void refresh() {
		send(Commands.COM_REFRESH);
		eatStatus(retrieve());
	}

	static if ((Options & ConnectionOptions.TextProtocol) == 0) {
		/// Sends COM_RESET_CONNECTION and consumes the status reply.
		void reset() {
			send(Commands.COM_RESET_CONNECTION);
			eatStatus(retrieve());
		}
	}

	/// Returns the server statistics string (COM_STATISTICS).
	/// NOTE(review): presumably a view into the internal receive buffer; `.dup` it if it
	/// must outlive the next command.
	const(char)[] statistics() {
		send(Commands.COM_STATISTICS);

		auto answer = retrieve();
		return answer.eat!(const(char)[])(answer.remaining);
	}

	/// The currently cached default schema name.
	const(char)[] schema() const {
		return schema_;
	}

	/// A copy of the current connection settings.
	ConnectionSettings settings() const {
		return settings_;
	}

	/// Prepares `sql` on the server (COM_STMT_PREPARE).
	/// The parameter and column definitions sent back are read and discarded.
	/// Returns: a handle holding the statement id and the parameter count.
	auto prepare(const(char)[] sql) {
		send(Commands.COM_STMT_PREPARE, sql);

		auto answer = retrieve();

		// Anything other than an OK header is handed to eatStatus (an ERR packet throws).
		if (answer.peek!ubyte != StatusPackets.OK_Packet)
			eatStatus(answer);

		answer.expect!ubyte(0);

		auto id = answer.eat!uint;
		auto columns = answer.eat!ushort;
		auto params = answer.eat!ushort;
		answer.expect!ubyte(0);

		auto warnings = answer.eat!ushort;

		if (params) {
			foreach (i; 0..params)
				skipColumnDef(retrieve(), Commands.COM_STMT_PREPARE);

			eatEOF(retrieve());
		}

		if (columns) {
			foreach (i; 0..columns)
				skipColumnDef(retrieve(), Commands.COM_STMT_PREPARE);

			eatEOF(retrieve());
		}

		return PreparedStatement(id, params);
	}

	/// Executes `sql` with `args`. If the last argument is callable, it is used as the row handler.
	/// With `ConnectionOptions.TextProtocol` this is a text-protocol query; otherwise a
	/// one-off prepared statement is created, executed and closed.
	void execute(Args...)(const(char)[] sql, Args args) {
		static if (Options & ConnectionOptions.TextProtocol) {
			query(sql, args);
		} else {
			scope(failure) disconnect_();

			auto id = prepare(sql);
			execute(id, args);
			close(id);
		}
	}

	/// Sets a session variable (`set session <variable>=<value>`).
	/// The variable name is inserted as a `MySQLFragment`; the value is a regular argument.
	void set(T)(const(char)[] variable, T value) {
		query("set session ?=?", MySQLFragment(variable), value);
	}

	/// Returns a copy of the value of session variable `variable`, or `null` if no row matched.
	const(char)[] get(const(char)[] variable) {
		const(char)[] result;
		query("show session variables like ?", variable, (MySQLRow row) {
			result = row[1].peek!(const(char)[]).dup;
		});

		return result;
	}

	/// Starts a transaction.
	/// Throws: MySQLErrorException if a transaction is already active.
	void begin() {
		if (inTransaction)
			throw new MySQLErrorException("MySQL does not support nested transactions - commit or rollback before starting a new transaction");

		query("start transaction");

		assert(inTransaction);
	}

	/// Commits the active transaction.
	/// Throws: MySQLErrorException if no transaction is active.
	void commit() {
		if (!inTransaction)
			throw new MySQLErrorException("No active transaction");

		query("commit");

		assert(!inTransaction);
	}

	/// Rolls back the active transaction. Does nothing when not connected.
	/// Throws: MySQLErrorException if connected and no transaction is active.
	void rollback() {
		if (connected) {
			if ((status_.flags & StatusFlags.SERVER_STATUS_IN_TRANS) == 0)
				throw new MySQLErrorException("No active transaction");

			query("rollback");

			assert(!inTransaction);
		}
	}

	/// True when connected and the last status flags report an open transaction.
	@property bool inTransaction() const {
		return connected && (status_.flags & StatusFlags.SERVER_STATUS_IN_TRANS);
	}

	/// Executes prepared statement `stmt` with `args` (COM_STMT_EXECUTE, binary protocol).
	///
	/// If the last argument is callable, it is used as the row handler. Otherwise any result
	/// set is discarded. Non-string array arguments expand to one parameter per element.
	/// Throws: MySQLErrorException when the number of bound parameters differs from `stmt.params`.
	void execute(Args...)(PreparedStatement stmt, Args args) {
		scope(failure) disconnect_();

		ensureConnected();

		seq_ = 0;
		auto packet = OutputPacket(&out_);
		packet.put!ubyte(Commands.COM_STMT_EXECUTE);
		packet.put!uint(stmt.id);
		packet.put!ubyte(Cursors.CURSOR_TYPE_READ_ONLY);
		packet.put!uint(1); // iteration count

		static if (args.length == 0) {
			enum shouldDiscard = true;
		} else {
			enum shouldDiscard = !isCallable!(args[args.length - 1]);
		}

		enum argCount = shouldDiscard ? args.length : (args.length - 1);

		if (!argCount && stmt.params)
			throw new MySQLErrorException(format("Wrong number of parameters for query. Got 0 but expected %d.", stmt.params));

		static if (argCount) {
			// The NULL bitmap is built in windows of NullsCapacity bits and flushed to the
			// packet whenever a window fills up (or after the last argument).
			enum NullsCapacity = 128; // must be power of 2
			ubyte[NullsCapacity >> 3] nulls;
			size_t bitsOut;
			size_t indexArg;
			foreach(i, arg; args[0..argCount]) {
				// Byte/bit of this argument inside the current window. Non-final flushes are
				// always exactly NullsCapacity bits, so bitsOut is a multiple of NullsCapacity
				// here and masking by the window's byte count keeps the index inside `nulls`.
				const auto index = (indexArg >> 3) & ((NullsCapacity >> 3) - 1);
				const auto bit = indexArg & 7;

				static if (is(typeof(arg) == typeof(null))) {
					nulls[index] = nulls[index] | (1 << bit);
					++indexArg;
				} else static if (is(Unqual!(typeof(arg)) == MySQLValue)) {
					if (arg.isNull)
						nulls[index] = nulls[index] | (1 << bit);
					++indexArg;
				} else static if (isArray!(typeof(arg)) && !isSomeString!(typeof(arg))) {
					indexArg += arg.length;
				} else {
					++indexArg;
				}

				auto finishing = (i == argCount - 1);
				auto remaining = indexArg - bitsOut;

				if (finishing || (remaining >= NullsCapacity)) {
					while (remaining) {
						auto bits = min(remaining, NullsCapacity);

						packet.put(nulls[0..(bits + 7) >> 3]);
						bitsOut += bits;
						nulls[] = 0;

						remaining = (indexArg - bitsOut);
						if (!remaining || (!finishing && (remaining < NullsCapacity)))
							break;
					}
				}
			}
			packet.put!ubyte(1); // new-params-bound flag: parameter types follow

			if (indexArg != stmt.params)
				throw new MySQLErrorException(format("Wrong number of parameters for query. Got %d but expected %d.", indexArg, stmt.params));

			// All parameter types first, then all parameter values; enums are sent as their base type.
			foreach (arg; args[0..argCount]) {
				static if (is(typeof(arg) == enum)) {
					putValueType(packet, cast(OriginalType!(Unqual!(typeof(arg))))arg);
				} else {
					putValueType(packet, arg);
				}
			}

			foreach (arg; args[0..argCount]) {
				static if (is(typeof(arg) == enum)) {
					putValue(packet, cast(OriginalType!(Unqual!(typeof(arg))))arg);
				} else {
					putValue(packet, arg);
				}
			}
		}

		packet.finalize(seq_);
		++seq_;

		socket_.write(packet.get());

		auto answer = retrieve();
		if (isStatus(answer)) {
			eatStatus(answer);
		} else {
			static if (!shouldDiscard) {
				resultSet(answer, stmt.id, Commands.COM_STMT_EXECUTE, args[args.length - 1]);
			} else {
				discardAll(answer, Commands.COM_STMT_EXECUTE);
			}
		}
	}

	/// Closes prepared statement `stmt` on the server (COM_STMT_CLOSE; no reply is read).
	void close(PreparedStatement stmt) {
		uint[1] data = [ stmt.id ];
		send(Commands.COM_STMT_CLOSE, data);
	}

	/// Callback invoked with the parsed status after each OK/EOF/ERR packet.
	alias OnStatusCallback = scope void delegate(ConnectionStatus status, const(char)[] message);
	@property void onStatus(OnStatusCallback callback) {
		onStatus_ = callback;
	}

	@property OnStatusCallback onStatus() const {
		return onStatus_;
	}

	/// Callback invoked by the internal failure handler when it disconnects with a non-zero error code.
	alias OnDisconnectCallback = scope void delegate(ConnectionStatus status);
	@property void onDisconnect(OnDisconnectCallback callback) {
		onDisconnect_ = callback;
	}

	@property OnDisconnectCallback onDisconnect() const {
		return onDisconnect_;
	}

	/// Last insert id reported by the server.
	@property ulong insertID() const {
		return status_.insertID;
	}

	/// Rows affected by the last statement (returned without narrowing to size_t).
	@property ulong affected() const {
		return status_.affected;
	}

	/// Rows matched by the last statement, parsed from the server's info text.
	@property ulong matched() const {
		return status_.matched;
	}

	/// Rows changed by the last statement, parsed from the server's info text.
	@property ulong changed() const {
		return status_.changed;
	}

	/// Warning count from the last status packet.
	@property size_t warnings() const {
		return status_.warnings;
	}

	/// Error code from the last ERR packet, or 0.
	@property size_t error() const {
		return status_.error;
	}

	/// Info/message text from the last status packet.
	@property const(char)[] status() const {
		return info_;
	}

	/// Whether the underlying socket reports being connected.
	@property bool connected() const {
		return socket_.connected;
	}

	/// Closes the underlying socket.
	void disconnect() {
		socket_.close();
	}

	/// Prepares the connection for reuse, e.g. from a pool.
	/// Clears the callbacks, reconnects if needed, rolls back any open transaction and
	/// switches back to the configured schema if it differs.
	void reuse() {
		onDisconnect_ = null;
		onStatus_ = null;

		ensureConnected();

		if (inTransaction)
			rollback;
		if (settings_.db.length && (settings_.db != schema_))
			use(settings_.db);
	}

	/// Enables or disables query tracing (only effective in `version(development)` builds).
	@property void trace(bool enable) {
		trace_ = enable;
	}

	@property bool trace() {
		return trace_;
	}

private:
	// Closes the socket and, if an error is recorded, notifies onDisconnect_.
	void disconnect_() {
		disconnect();
		if (onDisconnect_ && error)
			onDisconnect_(status_);
	}

	// Runs `sql` over the text protocol (COM_QUERY), substituting `?` placeholders client-side.
	// If the last argument is callable it receives the rows; otherwise any result is discarded.
	void query(Args...)(const(char)[] sql, Args args) {
		scope(failure) disconnect_();

		static if (args.length == 0) {
			enum shouldDiscard = true;
		} else {
			enum shouldDiscard = !isCallable!(args[args.length - 1]);
		}

		enum argCount = shouldDiscard ? args.length : (args.length - 1);

		static if (argCount || (Options & ConnectionOptions.TextProtocolCheckNoArgs)) {
			auto querySQL = prepareSQL(sql, args[0..argCount]);
		} else {
			auto querySQL = sql;
		}

		version(development) {
			import std.stdio : stderr, writefln;
			if (trace_)
				stderr.writefln("%s", querySQL);
		}

		send(Commands.COM_QUERY, querySQL);

		auto answer = retrieve();
		if (isStatus(answer)) {
			eatStatus(answer);
		} else {
			static if (!shouldDiscard) {
				resultSetText(answer, Commands.COM_QUERY, args[args.length - 1]);
			} else {
				discardAll(answer, Commands.COM_QUERY);
			}
		}
	}

	// Opens the socket and runs the handshake using settings_.
	void connect() {
		socket_.connect(settings_.host, settings_.port);

		seq_ = 0;
		eatHandshake(retrieve());
	}

	// Sends a command whose payload is the raw bytes of `data`.
	void send(T)(Commands cmd, T[] data) {
		send(cmd, cast(ubyte*)data.ptr, data.length * T.sizeof);
	}

	// Sends a command header followed by `length` payload bytes; resets the sequence id.
	void send(Commands cmd, ubyte* data = null, size_t length = 0) {
		ensureConnected();

		seq_ = 0;
		auto header = OutputPacket(&out_);
		header.put!ubyte(cmd);
		header.finalize(seq_, length);
		++seq_;

		socket_.write(header.get());
		if (length)
			socket_.write(data[0..length]);
	}

	// Reconnects if the socket is not connected.
	void ensureConnected() {
		if (!socket_.connected)
			connect();
	}

	// True when the packet is an OK or ERR packet.
	bool isStatus(InputPacket packet) {
		auto id = packet.peek!ubyte;
		switch (id) {
		case StatusPackets.ERR_Packet:
		case StatusPackets.OK_Packet:
			return true;
		default:
			return false;
		}
	}

	// Consumes the packet via eatStatus if it is an OK/ERR packet; otherwise leaves it untouched.
	void check(InputPacket packet, bool smallError = false) {
		auto id = packet.peek!ubyte;
		switch (id) {
		case StatusPackets.ERR_Packet:
		case StatusPackets.OK_Packet:
			eatStatus(packet, smallError);
			break;
		default:
			break;
		}
	}

	// Reads one packet: a 4-byte header (3-byte little-endian length + sequence id), then the payload.
	// Throws MySQLConnectionException on an out-of-order sequence id.
	InputPacket retrieve() {
		scope(failure) disconnect_();

		ubyte[4] header;
		socket_.read(header);

		auto len = header.ptr[0] | (header.ptr[1] << 8) | (header.ptr[2] << 16);
		auto seq = header.ptr[3];

		if (seq != seq_)
			throw new MySQLConnectionException("Out of order packet received");

		++seq_;

		in_.length = len;
		socket_.read(in_);

		if (in_.length != len)
			throw new MySQLConnectionException("Wrong number of bytes read");

		return InputPacket(&in_);
	}

	// Parses the server handshake, optionally upgrades to SSL, sends the auth response
	// and consumes the final status packet.
	void eatHandshake(InputPacket packet) {
		scope(failure) disconnect_();

		check(packet, true);

		server_.protocol = packet.eat!ubyte;
		server_.versionString = packet.eat!(const(char)[])(packet.countUntil(0, true)).dup;
		packet.skip(1);

		server_.connection = packet.eat!uint;

		// Auth data arrives in two parts: the first 8 bytes here, the rest further below.
		const auto authLengthStart = 8;
		size_t authLength = authLengthStart;

		ubyte[256] auth;
		auth[0..authLength] = packet.eat!(ubyte[])(authLength);

		packet.expect!ubyte(0);

		server_.caps = packet.eat!ushort;

		if (!packet.empty) {
			server_.charSet = packet.eat!ubyte;
			server_.status = packet.eat!ushort;
			server_.caps |= packet.eat!ushort << 16;
			server_.caps |= CapabilityFlags.CLIENT_LONG_PASSWORD;

			if ((server_.caps & CapabilityFlags.CLIENT_PROTOCOL_41) == 0)
				throw new MySQLProtocolException("Server doesn't support protocol v4.1");

			if (server_.caps & CapabilityFlags.CLIENT_SECURE_CONNECTION) {
				packet.skip(1);
			} else {
				packet.expect!ubyte(0);
			}

			packet.skip(10);

			authLength += packet.countUntil(0, true);
			if (authLength > auth.length)
				throw new MySQLConnectionException("Bad packet format");

			auth[authLengthStart..authLength] = packet.eat!(ubyte[])(authLength - authLengthStart);

			packet.expect!ubyte(0);
		}

		// Only capabilities supported by both sides are used.
		caps_ = cast(CapabilityFlags)(settings_.caps & server_.caps);

		if (((settings_.caps & CapabilityFlags.CLIENT_SSL) != 0) || settings_.ssl.enforce) {
			if ((caps_ & CapabilityFlags.CLIENT_SSL) != 0) {
				startSSL();
			} else if (settings_.ssl.enforce) {
				throw new MySQLProtocolException("Server doesn't support SSL");
			}
		}

		// token = SHA1(pwd) XOR SHA1(authData ~ SHA1(SHA1(pwd)))
		ubyte[20] token;
		{
			import std.digest.sha : sha1Of;

			auto pass = sha1Of(cast(const(ubyte)[])settings_.pwd);
			token = sha1Of(auth[0..authLength], sha1Of(pass));

			foreach (i; 0..20)
				token[i] = token[i] ^ pass[i];
		}

		auto reply = OutputPacket(&out_);

		reply.reserve(64 + settings_.user.length + settings_.pwd.length + settings_.db.length);

		reply.put!uint(caps_);
		reply.put!uint(1);
		reply.put!ubyte(45); // NOTE(review): presumably the utf8mb4_general_ci collation id — confirm
		reply.fill(0, 23);

		reply.put(settings_.user);
		reply.put!ubyte(0);

		if (settings_.pwd.length) {
			if (caps_ & CapabilityFlags.CLIENT_SECURE_CONNECTION) {
				reply.put!ubyte(token.length);
				reply.put(token);
			} else {
				reply.put(token);
				reply.put!ubyte(0);
			}
		} else {
			reply.put!ubyte(0);
		}

		// On reconnect, the previously cached schema takes precedence over the configured db.
		if ((settings_.db.length || schema_.length) && (caps_ & CapabilityFlags.CLIENT_CONNECT_WITH_DB)) {
			if (schema_.length) {
				reply.put(schema_);
			} else {
				reply.put(settings_.db);

				schema_.length = settings_.db.length;
				schema_[] = settings_.db[];
			}
		}

		reply.put!ubyte(0);

		reply.finalize(seq_);
		++seq_;

		socket_.write(reply.get());

		eatStatus(retrieve());
	}


	// Sends the short SSL request packet, then hands the socket over to TLS.
	void startSSL() {
		auto request = OutputPacket(&out_);

		request.reserve(64);

		request.put!uint(caps_);
		request.put!uint(1);
		request.put!ubyte(45);
		request.fill(0, 23);

		request.finalize(seq_);
		++seq_;

		socket_.write(request.get());

		socket_.startSSL(settings_.host, settings_.ssl);
	}


	// Parses an OK, EOF or ERR packet into status_/info_ and notifies onStatus_.
	// ERR packets always throw (a specific MySQL*Exception for known error codes).
	// Any other packet type throws MySQLProtocolException.
	void eatStatus(InputPacket packet, bool smallError = false) {
		auto id = packet.eat!ubyte;

		switch (id) {
		case StatusPackets.OK_Packet:
			status_.matched = 0;
			status_.changed = 0;
			status_.affected = packet.eatLenEnc();
			status_.insertID = packet.eatLenEnc();
			status_.flags = packet.eat!ushort;
			if (caps_ & CapabilityFlags.CLIENT_PROTOCOL_41)
				status_.warnings = packet.eat!ushort;
			status_.error = 0;
			info([]);

			if (caps_ & CapabilityFlags.CLIENT_SESSION_TRACK) {
				if (!packet.empty) {
					info(packet.eat!(const(char)[])(cast(size_t)packet.eatLenEnc()));

					if (status_.flags & StatusFlags.SERVER_SESSION_STATE_CHANGED) {
						packet.skipLenEnc();
						while (!packet.empty()) {
							final switch (packet.eat!ubyte()) with (SessionStateType) {
							case SESSION_TRACK_SCHEMA:
								// The server reports the new default schema; cache it.
								packet.skipLenEnc();
								schema_.length = cast(size_t)packet.eatLenEnc();
								schema_[] = packet.eat!(const(char)[])(schema_.length);
								break;
							case SESSION_TRACK_SYSTEM_VARIABLES:
							case SESSION_TRACK_GTIDS:
							case SESSION_TRACK_STATE_CHANGE:
							case SESSION_TRACK_TRANSACTION_STATE:
							case SESSION_TRACK_TRANSACTION_CHARACTERISTICS:
								packet.skip(cast(size_t)packet.eatLenEnc());
								break;
							}
						}
					}
				}
			} else {
				info(packet.eat!(const(char)[])(packet.remaining));
			}

			// Pull "matched: N changed: M" counts out of the info text, when present.
			import std.regex : matchFirst, regex;
			static matcher = regex(`\smatched:\s*(\d+)\s+changed:\s*(\d+)`, `i`);
			auto matches = matchFirst(info_, matcher);
			if (!matches.empty) {
				status_.matched = matches[1].to!ulong;
				status_.changed = matches[2].to!ulong;
			}

			if (onStatus_)
				onStatus_(status_, info_);

			break;
		case StatusPackets.EOF_Packet:
			status_.affected = 0;
			status_.changed = 0;
			status_.matched = 0;
			status_.error = 0;
			status_.warnings = packet.eat!ushort;
			status_.flags = packet.eat!ushort;
			info([]);

			if (onStatus_)
				onStatus_(status_, info_);

			break;
		case StatusPackets.ERR_Packet:
			status_.affected = 0;
			status_.changed = 0;
			status_.matched = 0;
			status_.flags = 0;
			status_.warnings = 0;
			status_.error = packet.eat!ushort;
			// Full ERR packets carry 6 more bytes before the message; the handshake's "small" form does not.
			if (!smallError)
				packet.skip(6);
			info(packet.eat!(const(char)[])(packet.remaining));

			if (onStatus_)
				onStatus_(status_, info_);

			switch(status_.error) {
			case ErrorCodes.ER_DUP_ENTRY_WITH_KEY_NAME:
			case ErrorCodes.ER_DUP_ENTRY:
				throw new MySQLDuplicateEntryException(info_.idup);
			case ErrorCodes.ER_DATA_TOO_LONG_FOR_COL:
				throw new MySQLDataTooLongException(info_.idup);
			case ErrorCodes.ER_DEADLOCK_FOUND:
				throw new MySQLDeadlockFoundException(info_.idup);
			case ErrorCodes.ER_TABLE_DOESNT_EXIST:
				throw new MySQLTableDoesntExistException(info_.idup);
			case ErrorCodes.ER_LOCK_WAIT_TIMEOUT:
				throw new MySQLLockWaitTimeoutException(info_.idup);
			default:
				version(development) {
					// On dev show the query together with the error message
					throw new MySQLErrorException(format("[err:%s] %s - %s", status_.error, info_, sql_.data));
				} else {
					throw new MySQLErrorException(format("[err:%s] %s", status_.error, info_));
				}
			}
		default:
			throw new MySQLProtocolException("Unexpected packet format");
		}
	}

	// Copies `value` into the connection-owned info_ buffer.
	void info(const(char)[] value) {
		info_.length = value.length;
		info_[0..$] = value;
	}

	// Reads and discards one column definition packet.
	void skipColumnDef(InputPacket packet, Commands cmd) {
		packet.skip(cast(size_t)packet.eatLenEnc()); // catalog
		packet.skip(cast(size_t)packet.eatLenEnc()); // schema
		packet.skip(cast(size_t)packet.eatLenEnc()); // table
		packet.skip(cast(size_t)packet.eatLenEnc()); // original_table
		packet.skip(cast(size_t)packet.eatLenEnc()); // name
		packet.skip(cast(size_t)packet.eatLenEnc()); // original_name
		packet.skipLenEnc(); // next_length
		packet.skip(10); // 2 + 4 + 1 + 2 + 1 // charset, length, type, flags, decimals
		packet.expect!ushort(0);

		if (cmd == Commands.COM_FIELD_LIST)
			packet.skip(cast(size_t)packet.eatLenEnc());// default values
	}

	// Reads one column definition into `def`. The column name is copied into columns_
	// and `def.name` is a slice of that buffer.
	void columnDef(InputPacket packet, Commands cmd, ref MySQLColumn def) {
		packet.skip(cast(size_t)packet.eatLenEnc()); // catalog
		packet.skip(cast(size_t)packet.eatLenEnc()); // schema
		packet.skip(cast(size_t)packet.eatLenEnc()); // table
		packet.skip(cast(size_t)packet.eatLenEnc()); // original_table
		auto len = cast(size_t)packet.eatLenEnc();
		columns_ ~= packet.eat!(const(char)[])(len);
		def.name = columns_[$-len..$];
		packet.skip(cast(size_t)packet.eatLenEnc()); // original_name
		packet.skipLenEnc(); // next_length
		packet.skip(2); // charset
		def.length = packet.eat!uint;
		def.type = cast(ColumnTypes)packet.eat!ubyte;
		def.flags = packet.eat!ushort;
		def.decimals = packet.eat!ubyte;

		packet.expect!ushort(0);

		if (cmd == Commands.COM_FIELD_LIST)
			packet.skip(cast(size_t)packet.eatLenEnc());// default values
	}

	// Reads `count` column definition packets into `defs`.
	void columnDefs(size_t count, Commands cmd, ref MySQLColumn[] defs) {
		defs.length = count;
		foreach (i; 0..count)
			columnDef(retrieve(), cmd, defs[i]);
	}

	// The callHandler overloads adapt the supported row-handler signatures:
	//   (row), (index, row), (header, row), (index, header, row).
	// A handler may return void (keep reading) or bool (false stops reading rows).
	bool callHandler(RowHandler)(RowHandler handler, size_t, MySQLHeader, MySQLRow row) if ((ParameterTypeTuple!(RowHandler).length == 1) && is(ParameterTypeTuple!(RowHandler)[0] == MySQLRow)) {
		static if (is(ReturnType!(RowHandler) == void)) {
			handler(row);
			return true;
		} else {
			return handler(row); // return type must be bool
		}
	}

	bool callHandler(RowHandler)(RowHandler handler, size_t i, MySQLHeader, MySQLRow row) if ((ParameterTypeTuple!(RowHandler).length == 2) && isNumeric!(ParameterTypeTuple!(RowHandler)[0]) && is(ParameterTypeTuple!(RowHandler)[1] == MySQLRow)) {
		static if (is(ReturnType!(RowHandler) == void)) {
			handler(cast(ParameterTypeTuple!(RowHandler)[0])i, row);
			return true;
		} else {
			return handler(cast(ParameterTypeTuple!(RowHandler)[0])i, row); // return type must be bool
		}
	}

	bool callHandler(RowHandler)(RowHandler handler, size_t, MySQLHeader header, MySQLRow row) if ((ParameterTypeTuple!(RowHandler).length == 2) && is(ParameterTypeTuple!(RowHandler)[0] == MySQLHeader) && is(ParameterTypeTuple!(RowHandler)[1] == MySQLRow)) {
		static if (is(ReturnType!(RowHandler) == void)) {
			handler(header, row);
			return true;
		} else {
			return handler(header, row); // return type must be bool
		}
	}

	bool callHandler(RowHandler)(RowHandler handler, size_t i, MySQLHeader header, MySQLRow row) if ((ParameterTypeTuple!(RowHandler).length == 3) && isNumeric!(ParameterTypeTuple!(RowHandler)[0]) && is(ParameterTypeTuple!(RowHandler)[1] == MySQLHeader) && is(ParameterTypeTuple!(RowHandler)[2] == MySQLRow)) {
		// Cast the index like the (index, row) overload so handlers taking e.g. `int` compile.
		static if (is(ReturnType!(RowHandler) == void)) {
			handler(cast(ParameterTypeTuple!(RowHandler)[0])i, header, row);
			return true;
		} else {
			return handler(cast(ParameterTypeTuple!(RowHandler)[0])i, header, row); // return type must be bool
		}
	}

	// Decodes one binary-protocol row into `row`. The NULL bitmap has a 2-bit offset.
	void resultSetRow(InputPacket packet, MySQLHeader header, ref MySQLRow row) {
		assert(row.columns.length == header.length);

		packet.expect!ubyte(0);
		auto nulls = packet.eat!(ubyte[])((header.length + 2 + 7) >> 3);
		foreach (i, ref column; header) {
			const auto index = (i + 2) >> 3; // bit offset of 2
			const auto bit = (i + 2) & 7;

			if ((nulls[index] & (1 << bit)) == 0) {
				eatValue(packet, column, row.get_(i));
			} else {
				auto signed = (column.flags & FieldFlags.UNSIGNED_FLAG) == 0;
				row.get_(i) = MySQLValue(column.name, ColumnTypes.MYSQL_TYPE_NULL, signed, null, 0);
			}
		}
		assert(packet.empty);
	}

	// Reads a binary-protocol result set and passes each row to `handler`.
	// When the server opened a cursor, rows are pulled with COM_STMT_FETCH in batches of 4096.
	void resultSet(RowHandler)(InputPacket packet, uint stmt, Commands cmd, RowHandler handler) {
		columns_.length = 0;

		auto columns = cast(size_t)packet.eatLenEnc();
		columnDefs(columns, cmd, header_);
		row_.header_(header_);

		auto status = retrieve();
		if (status.peek!ubyte == StatusPackets.ERR_Packet)
			eatStatus(status);

		size_t index;
		auto statusFlags = eatEOF(status);
		if (statusFlags & StatusFlags.SERVER_STATUS_CURSOR_EXISTS) {
			uint[2] data = [ stmt, 4096 ]; // todo: make setting - rows per fetch
			while (statusFlags & (StatusFlags.SERVER_STATUS_CURSOR_EXISTS | StatusFlags.SERVER_MORE_RESULTS_EXISTS)) {
				send(Commands.COM_STMT_FETCH, data);

				auto answer = retrieve();
				if (answer.peek!ubyte == StatusPackets.ERR_Packet)
					eatStatus(answer);

				auto row = answer.empty ? retrieve() : answer;
				while (true) {
					if (row.peek!ubyte == StatusPackets.EOF_Packet) {
						statusFlags = eatEOF(row);
						break;
					}

					resultSetRow(row, header_, row_);
					if (!callHandler(handler, index++, header_, row_)) {
						discardUntilEOF(retrieve());
						statusFlags = 0;
						break;
					}
					row = retrieve();
				}
			}
		} else {
			while (true) {
				auto row = retrieve();
				if (row.peek!ubyte == StatusPackets.EOF_Packet) {
					eatEOF(row);
					break;
				}

				resultSetRow(row, header_, row_);
				if (!callHandler(handler, index++, header_, row_)) {
					discardUntilEOF(retrieve());
					break;
				}
			}
		}
	}

	// Decodes one text-protocol row into `row`; a leading 0xfb byte marks a NULL column.
	void resultSetRowText(InputPacket packet, MySQLHeader header, ref MySQLRow row) {
		assert(row.columns.length == header.length);

		foreach(i, ref column; header) {
			if (packet.peek!ubyte != 0xfb) {
				eatValueText(packet, column, row.get_(i));
			} else {
				packet.skip(1);
				auto signed = (column.flags & FieldFlags.UNSIGNED_FLAG) == 0;
				row.get_(i) = MySQLValue(column.name, ColumnTypes.MYSQL_TYPE_NULL, signed, null, 0);
			}
		}
		assert(packet.empty);
	}

	// Reads a text-protocol result set and passes each row to `handler`.
	void resultSetText(RowHandler)(InputPacket packet, Commands cmd, RowHandler handler) {
		columns_.length = 0;

		auto columns = cast(size_t)packet.eatLenEnc();
		columnDefs(columns, cmd, header_);
		row_.header_(header_);

		eatEOF(retrieve());

		size_t index;
		while (true) {
			auto row = retrieve();
			if (row.peek!ubyte == StatusPackets.EOF_Packet) {
				eatEOF(row);
				break;
			} else if (row.peek!ubyte == StatusPackets.ERR_Packet) {
				eatStatus(row);
				break;
			}

			resultSetRowText(row, header_, row_);
			if (!callHandler(handler, index++, header_, row_)) {
				discardUntilEOF(retrieve());
				break;
			}
		}
	}

	// Reads and throws away a whole result set (column definitions and rows).
	void discardAll(InputPacket packet, Commands cmd) {
		auto columns = cast(size_t)packet.eatLenEnc();
		columnDefs(columns, cmd, header_);

		auto statusFlags = eatEOF(retrieve());
		if ((statusFlags & StatusFlags.SERVER_STATUS_CURSOR_EXISTS) == 0) {
			while (true) {
				auto row = retrieve();
				if (row.peek!ubyte == StatusPackets.EOF_Packet) {
					eatEOF(row);
					break;
				} else if (row.peek!ubyte == StatusPackets.ERR_Packet) {
					// The row stream can also end with an ERR packet (resultSetText handles
					// the same case); report it instead of waiting for an EOF.
					eatStatus(row);
					break;
				}
			}
		}
	}

	// Skips packets, starting with `packet`, until the terminating EOF (or ERR) packet.
	void discardUntilEOF(InputPacket packet) {
		while (true) {
			if (packet.peek!ubyte == StatusPackets.EOF_Packet) {
				eatEOF(packet);
				break;
			} else if (packet.peek!ubyte == StatusPackets.ERR_Packet) {
				eatStatus(packet);
				break;
			}
			packet = retrieve();
		}
	}

	// Consumes an EOF packet, updates status_ and returns the server status flags.
	// Throws MySQLProtocolException if the packet is not an EOF packet.
	auto eatEOF(InputPacket packet) {
		auto id = packet.eat!ubyte;
		if (id != StatusPackets.EOF_Packet)
			throw new MySQLProtocolException("Unexpected packet format");

		status_.error = 0;
		status_.warnings = packet.eat!ushort();
		status_.flags = packet.eat!ushort();
		info([]);

		if (onStatus_)
			onStatus_(status_, info_);

		return status_.flags;
	}

	// Adds a rough estimate of the rendered size of `args` to `estimated` and returns the
	// number of placeholders they fill (non-string arrays count once per element).
	auto estimateArgs(Args...)(ref size_t estimated, Args args) {
		size_t argCount;

		foreach(i, arg; args) {
			static if (is(typeof(arg) == typeof(null))) {
				++argCount;
				estimated += 4;
			} else static if (is(Unqual!(typeof(arg)) == MySQLValue)) {
				++argCount;
				final switch(arg.type) with (ColumnTypes) {
				case MYSQL_TYPE_NULL:
					estimated += 4;
					break;
				case MYSQL_TYPE_TINY:
					estimated += 4;
					break;
				case MYSQL_TYPE_YEAR:
				case MYSQL_TYPE_SHORT:
					estimated += 6;
					break;
				case MYSQL_TYPE_INT24:
				case MYSQL_TYPE_LONG:
					estimated += 6;
					break;
				case MYSQL_TYPE_LONGLONG:
					estimated += 8;
					break;
				case MYSQL_TYPE_FLOAT:
					estimated += 8;
					break;
				case MYSQL_TYPE_DOUBLE:
					estimated += 8;
					break;
				case MYSQL_TYPE_SET:
				case MYSQL_TYPE_ENUM:
				case MYSQL_TYPE_VARCHAR:
				case MYSQL_TYPE_VAR_STRING:
				case MYSQL_TYPE_STRING:
				case MYSQL_TYPE_JSON:
				case MYSQL_TYPE_NEWDECIMAL:
				case MYSQL_TYPE_DECIMAL:
				case MYSQL_TYPE_TINY_BLOB:
				case MYSQL_TYPE_MEDIUM_BLOB:
				case MYSQL_TYPE_LONG_BLOB:
				case MYSQL_TYPE_BLOB:
				case MYSQL_TYPE_BIT:
				case MYSQL_TYPE_GEOMETRY:
					estimated += 2 + arg.peek!(const(char)[]).length;
					break;
				case MYSQL_TYPE_TIME:
				case MYSQL_TYPE_TIME2:
					estimated += 18;
					break;
				case MYSQL_TYPE_DATE:
				case MYSQL_TYPE_NEWDATE:
				case MYSQL_TYPE_DATETIME:
				case MYSQL_TYPE_DATETIME2:
				case MYSQL_TYPE_TIMESTAMP:
				case MYSQL_TYPE_TIMESTAMP2:
					estimated += 20;
					break;
				}
			} else static if (isArray!(typeof(arg)) && !isSomeString!(typeof(arg))) {
				argCount += arg.length;
				estimated += arg.length * 6;
			} else static if (isSomeString!(typeof(arg)) || is(Unqual!(typeof(arg)) == MySQLRawString) || is(Unqual!(typeof(arg)) == MySQLFragment) || is(Unqual!(typeof(arg)) == MySQLBinary)) {
				++argCount;
				estimated += 2 + arg.length;
			} else {
				++argCount;
				estimated += 6;
			}
		}
		return argCount;
	}

	// Renders `sql` with each `?` placeholder outside quotes replaced by the next argument.
	// The result lives in the reusable sql_ buffer and is overwritten by the next call.
	// Throws MySQLErrorException when placeholders and arguments do not match in number.
	auto prepareSQL(Args...)(const(char)[] sql, Args args) {
		auto estimated = sql.length;
		auto argCount = estimateArgs(estimated, args);

		sql_.clear;
		sql_.reserve(max(8192, estimated));

		// One type-erased appender per argument, so the substitution loop below is not
		// re-instantiated for every argument type.
		alias AppendFunc = bool function(ref Appender!(char[]), ref const(char)[] sql, ref size_t, const(void)*) @safe pure nothrow;
		AppendFunc[Args.length] funcs;
		const(void)*[Args.length] addrs;

		foreach (i, Arg; Args) {
			static if (is(Arg == enum)) {
				funcs[i] = () @trusted { return cast(AppendFunc)&appendNextValue!(OriginalType!Arg); }();
				addrs[i] = (ref x) @trusted { return cast(const void*)&x; }(cast(OriginalType!(Unqual!Arg))args[i]);
			} else {
				funcs[i] = () @trusted { return cast(AppendFunc)&appendNextValue!(Arg); }();
				addrs[i] = (ref x) @trusted { return cast(const void*)&x; }(args[i]);
			}
		}

		size_t indexArg;
		foreach (i; 0..Args.length) {
			if (!funcs[i](sql_, sql, indexArg, addrs[i]))
				throw new MySQLErrorException(format("Wrong number of parameters for query. Got %d but expected %d.", argCount, indexArg));
		}

		finishCopy(sql_, sql, argCount, indexArg);

		return sql_.data;
	}

	// Copies the rest of `sql` into `app`; throws if unfilled placeholders remain.
	void finishCopy(ref Appender!(char[]) app, ref const(char)[] sql, size_t argCount, size_t indexArg) {
		if (copyUpToNext(app, sql)) {
			++indexArg;
			while (copyUpToNext(app, sql))
				++indexArg;
			throw new MySQLErrorException(format("Wrong number of parameters for query. Got %d but expected %d.", argCount, indexArg));
		}
	}

	SocketType socket_;
	MySQLHeader header_;
	MySQLRow row_;
	char[] columns_;  // backing storage for column names of the current result set
	char[] info_;     // message text of the last status packet
	char[] schema_;   // cached default schema name
	ubyte[] in_;      // receive buffer
	ubyte[] out_;     // send buffer
	ubyte seq_;       // expected/next packet sequence id
	Appender!(char[]) sql_; // buffer for client-side rendered SQL

	OnStatusCallback onStatus_;
	OnDisconnectCallback onDisconnect_;
	CapabilityFlags caps_;
	ConnectionStatus status_;
	ConnectionSettings settings_;
	ServerInfo server_;

	// For tracing queries
	bool trace_;
}

// Copies `sql` into `app` up to the next `?` that is not inside a '...', "..." or `...`
// quoted section, and advances `sql` past that `?`. Inside quotes a backslash escapes the
// next character. Returns true if a placeholder was found; otherwise copies everything and
// returns false.
private auto copyUpToNext(ref Appender!(char[]) app, ref const(char)[] sql) {
	size_t offset;
	dchar quote = '\0';

	while (offset < sql.length) {
		auto ch = decode!(UseReplacementDchar.no)(sql, offset);
		switch (ch) {
		case '?':
			if (!quote) {
				app.put(sql[0..offset - 1]);
				sql = sql[offset..$];
				return true;
			} else {
				goto default;
			}
		case '\'':
		case '\"':
		case '`':
			if (quote == ch) {
				quote = '\0';
			} else if (!quote) {
				quote = ch;
			}
			goto default;
		case '\\':
			if (quote && (offset < sql.length))
				decode!(UseReplacementDchar.no)(sql, offset);
			goto default;
		default:
			break;
		}
	}
	app.put(sql[0..offset]);
	sql = sql[offset..$];
	return false;
}

// Type-erased appender used by prepareSQL: `arg` points at a value of type T. Substitutes
// the value (or each element, for non-string arrays) into the next placeholder(s) and
// bumps `indexArg` per substituted value. Returns false if `sql` runs out of placeholders.
private bool appendNextValue(T)(ref Appender!(char[]) app, ref const(char)[] sql, ref size_t indexArg, const(void)* arg) {
	static if (isArray!T && !isSomeString!(OriginalType!T)) {
		foreach (i, ref v; *cast(T*)arg) {
			if (copyUpToNext(app, sql)) {
				appendValue(app, v);
				++indexArg;
			} else {
				return false;
			}
		}
	} else {
		if (copyUpToNext(app, sql)) {
			appendValue(app, *cast(T*)arg);
			++indexArg;
		} else {
			return false;
		}
	}
	return true;
}