module mysql.connection;


import std.algorithm;
import std.array;
import std.conv : to;
import std.regex : ctRegex, matchFirst;
import std.string;
import std.traits;
import std.utf : decode, UseReplacementDchar;

import mysql.appender;
public import mysql.exception;
import mysql.packet;
import mysql.protocol;
public import mysql.type;


/// Capability flags requested by default when connecting.
immutable CapabilityFlags DefaultClientCaps = CapabilityFlags.CLIENT_LONG_PASSWORD | CapabilityFlags.CLIENT_LONG_FLAG |
    CapabilityFlags.CLIENT_CONNECT_WITH_DB | CapabilityFlags.CLIENT_PROTOCOL_41 | CapabilityFlags.CLIENT_SECURE_CONNECTION | CapabilityFlags.CLIENT_SESSION_TRACK;


/// Outcome of the most recent command, filled from the server's OK/EOF/ERR packets.
struct ConnectionStatus {
    ulong affected;   /// affected-rows field of the last OK packet
    ulong matched;    /// "matched: N" parsed from the info message, 0 if absent
    ulong changed;    /// "changed: N" parsed from the info message, 0 if absent
    ulong insertID;   /// last-insert-id field of the last OK packet
    ushort flags;     /// server status flags
    ushort error;     /// error code of the last ERR packet, 0 otherwise
    ushort warnings;  /// warning count
}


/**
 * Connection parameters.
 *
 * Can be built from a connection string made of `name=value` pairs separated
 * by `;`, e.g. `"host=localhost;port=3306;user=root;pwd=secret;db=test"`.
 * Recognised names are `host`, `user`, `pwd`, `db` and `port`. Names and
 * values are whitespace-stripped; empty segments (such as a trailing `;`)
 * are ignored.
 */
struct ConnectionSettings {
    /**
     * Parses `connectionString`.
     *
     * Throws: `MySQLException` when the string holds no pairs, a segment has
     * no `=`, or a name is unknown; `std.conv.ConvException` when `port` is
     * not a valid `ushort`.
     */
    this(const(char)[] connectionString) {
        auto remaining = connectionString;
        bool parsedAny = false;

        while (!remaining.empty) {
            // Split off the next "name=value" segment.
            const(char)[] pair;
            auto indexEnd = remaining.indexOf(';');
            if (indexEnd < 0) {
                pair = remaining;
                remaining = null;
            } else {
                pair = remaining[0..indexEnd];
                remaining = remaining[indexEnd + 1..$];
            }

            // Tolerate empty segments, e.g. a trailing ';' or ";;".
            if (pair.strip.empty)
                continue;

            // A segment without '=' used to slice with index -1 and crash.
            auto indexValue = pair.indexOf('=');
            if (indexValue < 0)
                throw new MySQLException("Bad connection string: " ~ cast(string)connectionString);

            auto name = strip(pair[0..indexValue]);
            auto value = strip(pair[indexValue + 1..$]);

            switch (name) {
            case "host":
                host = value;
                break;
            case "user":
                user = value;
                break;
            case "pwd":
                pwd = value;
                break;
            case "db":
                db = value;
                break;
            case "port":
                port = to!ushort(value);
                break;
            default:
                throw new MySQLException("Bad connection string: " ~ cast(string)connectionString);
            }

            parsedAny = true;
        }

        if (!parsedAny)
            throw new MySQLException("Bad connection string: " ~ cast(string)connectionString);
    }

    CapabilityFlags caps = DefaultClientCaps;

    const(char)[] host;
    const(char)[] user;
    const(char)[] pwd;
    const(char)[] db;
    ushort port = 3306;
}


/// Fields read from the server's initial handshake packet.
private struct ServerInfo {
    const(char)[] versionString;
    ubyte protocol;
    ubyte charSet;
    ushort status;
    uint connection;
    uint caps;
}


/**
 * Builds a comma-separated list of `x` `?` placeholders, wrapped in
 * parentheses when `parens` is true: `placeholders(3)` is `"(?,?,?)"`,
 * `placeholders(3, false)` is `"?,?,?"`.
 *
 * Returns: the placeholder list, or `null` when `x` is 0.
 */
@property string placeholders(size_t x, bool parens = true) {
    if (!x)
        return null;

    auto app = appender!string;
    // "?,?,...,?" is 2*x - 1 characters, plus 2 for the optional parentheses.
    app.reserve(x + x - 1 + (parens ? 2 : 0));

    if (parens)
        app.put('(');
    foreach (i; 0..x - 1)
        app.put("?,");
    app.put('?');
    if (parens)
        app.put(')');

    return app.data;
}


/// ditto; uses the length of `x`. `parens` is forwarded (it used to be ignored).
@property string placeholders(T)(T[] x, bool parens = true) {
    return placeholders(x.length, parens);
}


/// Handle of a server-side prepared statement returned by `Connection.prepare`.
struct PreparedStatement {
package:
    uint id;
    uint params;
}


enum ConnectionOptions {
    TextProtocol = 1 << 0, // Execute method uses the MySQL text protocol under the hood - it's less safe but can increase performance in some situations
    TextProtocolCheckNoArgs = 1 << 1, // Check for orphan placeholders even if arguments are passed
    Default = 0
}


/**
 * A single MySQL client connection over `SocketType`.
 *
 * `SocketType` must provide `connect(host, port)`, `connected`, `read`,
 * `write` and `close` as used below.
 */
struct Connection(SocketType, ConnectionOptions Options = ConnectionOptions.Default) {
    /// Connects using a connection string (see `ConnectionSettings`).
    void connect(string connectionString) {
        settings_ = ConnectionSettings(connectionString);
        connect();
    }

    /// Connects using pre-built settings.
    void connect(ConnectionSettings settings) {
        settings_ = settings;
        connect();
    }

    /// Connects with explicit parameters. CLIENT_LONG_PASSWORD and CLIENT_PROTOCOL_41 are always added to `caps`.
    void connect(const(char)[] host, ushort port, const(char)[] user, const(char)[] pwd, const(char)[] db, CapabilityFlags caps = DefaultClientCaps) {
        settings_.host = host;
        settings_.user = user;
        settings_.pwd = pwd;
        settings_.db = db;
        settings_.port = port;
        settings_.caps = caps | CapabilityFlags.CLIENT_LONG_PASSWORD | CapabilityFlags.CLIENT_PROTOCOL_41;

        connect();
    }

    /// Switches the default schema (COM_INIT_DB) and updates the cached `schema`.
    void use(const(char)[] db) {
        send(Commands.COM_INIT_DB, db);
        eatStatus(retrieve());

        // eatStatus throws on an ERR reply, so the server is now using `db`.
        // eatStatus skips session-state-change info, so the cached name must be
        // updated here even when CLIENT_SESSION_TRACK was negotiated.
        schema_.length = db.length;
        schema_[] = db[];
    }

    /// Sends COM_PING and consumes the status reply.
    void ping() {
        send(Commands.COM_PING);
        eatStatus(retrieve());
    }

    /// Sends COM_REFRESH and consumes the status reply.
    void refresh() {
        send(Commands.COM_REFRESH);
        eatStatus(retrieve());
    }

    /// Sends COM_RESET_CONNECTION and consumes the status reply.
    void reset() {
        send(Commands.COM_RESET_CONNECTION);
        eatStatus(retrieve());
    }

    /// Returns the server's COM_STATISTICS text as an independent copy.
    const(char)[] statistics() {
        send(Commands.COM_STATISTICS);

        auto answer = retrieve();
        // Copy: the packet slices the receive buffer, which the next command reuses.
        return answer.eat!(const(char)[])(answer.remaining).dup;
    }

    /// Current default schema as last set by `connect` (with a db) or `use`.
    const(char)[] schema() const {
        return schema_;
    }

    /// Settings used for the current/last connection.
    ConnectionSettings settings() const {
        return settings_;
    }

    /**
     * Prepares `sql` on the server (COM_STMT_PREPARE).
     *
     * The parameter and column definitions sent by the server are read and
     * skipped. Returns: a `PreparedStatement` with the statement id and
     * parameter count.
     */
    auto prepare(const(char)[] sql) {
        send(Commands.COM_STMT_PREPARE, sql);

        auto answer = retrieve();

        if (answer.peek!ubyte != StatusPackets.OK_Packet)
            eatStatus(answer);

        answer.expect!ubyte(0);

        auto id = answer.eat!uint;
        auto columns = answer.eat!ushort;
        auto params = answer.eat!ushort;
        answer.expect!ubyte(0);

        answer.eat!ushort(); // warning count - not used

        if (params) {
            foreach (i; 0..params)
                skipColumnDef(retrieve(), Commands.COM_STMT_PREPARE);

            eatEOF(retrieve());
        }

        if (columns) {
            foreach (i; 0..columns)
                skipColumnDef(retrieve(), Commands.COM_STMT_PREPARE);

            eatEOF(retrieve());
        }

        return PreparedStatement(id, params);
    }

    /**
     * Executes `sql` with `args`. If the last argument is callable, it is used
     * as the row handler; otherwise any result set is discarded.
     *
     * With `ConnectionOptions.TextProtocol` the arguments are interpolated
     * client-side; otherwise the statement is prepared, executed and closed.
     */
    void execute(Args...)(const(char)[] sql, Args args) {
        static if (Options & ConnectionOptions.TextProtocol) {
            query(sql, args);
        } else {
            scope(failure) disconnect();

            auto id = prepare(sql);
            execute(id, args);
            close(id);
        }
    }

    /// Sets a session variable via `set session <variable>=<value>`.
    void set(T)(const(char)[] variable, T value) {
        query("set session ?=?", MySQLFragment(variable), value);
    }

    /// Returns a copy of a session variable's value, or null if no row matches.
    const(char)[] get(const(char)[] variable) {
        const(char)[] result;
        query("show session variables like ?", variable, (MySQLRow row) {
            result = row[1].peek!(const(char)[]).dup;
        });

        return result;
    }

    /// Starts a transaction. Throws if one is already active.
    void begin() {
        if (inTransaction)
            throw new MySQLErrorException("MySQL does not support nested transactions - commit or rollback before starting a new transaction");

        query("start transaction");

        assert(inTransaction);
    }

    /// Commits the active transaction. Throws if none is active.
    void commit() {
        if (!inTransaction)
            throw new MySQLErrorException("No active transaction");

        query("commit");

        assert(!inTransaction);
    }

    /// Rolls back the active transaction. Does nothing when disconnected; throws if no transaction is active.
    void rollback() {
        if (connected) {
            if ((status_.flags & StatusFlags.SERVER_STATUS_IN_TRANS) == 0)
                throw new MySQLErrorException("No active transaction");

            query("rollback");

            assert(!inTransaction);
        }
    }

    /// True when connected and the last status flags report an open transaction.
    @property bool inTransaction() const {
        return connected && (status_.flags & StatusFlags.SERVER_STATUS_IN_TRANS);
    }

    /**
     * Executes a prepared statement (COM_STMT_EXECUTE) with `args` bound as
     * parameters. Non-string arrays expand to one parameter per element.
     * If the last argument is callable it receives the rows; otherwise any
     * result set is discarded.
     *
     * Throws: `MySQLErrorException` when the number of bound parameters does
     * not match the statement's parameter count. Disconnects on any failure.
     */
    void execute(Args...)(PreparedStatement stmt, Args args) {
        scope(failure) disconnect();

        ensureConnected();

        seq_ = 0;
        auto packet = OutputPacket(&out_);
        packet.put!ubyte(Commands.COM_STMT_EXECUTE);
        packet.put!uint(stmt.id);
        packet.put!ubyte(Cursors.CURSOR_TYPE_READ_ONLY);
        packet.put!uint(1); // iteration count

        static if (args.length == 0) {
            enum shouldDiscard = true;
        } else {
            enum shouldDiscard = !isCallable!(args[args.length - 1]);
        }

        enum argCount = shouldDiscard ? args.length : (args.length - 1);

        if (!argCount && stmt.params)
            throw new MySQLErrorException(format("Wrong number of parameters for query. Got 0 but expected %d.", stmt.params));

        static if (argCount) {
            // NULL bitmap, one bit per bound parameter. Bits are buffered in
            // `nulls` and flushed in chunks of NullsCapacity bits, so any number
            // of parameters fits in a fixed-size buffer.
            enum NullsCapacity = 128; // must be power of 2
            ubyte[NullsCapacity >> 3] nulls;
            size_t bitsOut = 0;   // bits already written to the packet
            size_t indexArg = 0;  // parameters consumed so far
            foreach(i, arg; args[0..argCount]) {
                // Offset inside the buffered window. Before the final flush,
                // bitsOut only advances in whole NullsCapacity chunks, so this
                // stays below NullsCapacity and the byte index stays inside
                // `nulls` (the old `& (NullsCapacity - 1)` mask allowed up to 127
                // for a 16-byte buffer).
                const auto offset = indexArg - bitsOut;
                const auto index = offset >> 3;
                const auto bit = offset & 7;

                static if (is(typeof(arg) == typeof(null))) {
                    nulls[index] = nulls[index] | (1 << bit);
                    ++indexArg;
                } else static if (is(Unqual!(typeof(arg)) == MySQLValue)) {
                    if (arg.isNull)
                        nulls[index] = nulls[index] | (1 << bit);
                    ++indexArg;
                } else static if (isArray!(typeof(arg)) && !isSomeString!(typeof(arg))) {
                    indexArg += arg.length;
                } else {
                    ++indexArg;
                }

                auto finishing = (i == argCount - 1);
                auto remaining = indexArg - bitsOut;

                // Flush full chunks as they fill up, and everything on the last argument.
                if (finishing || (remaining >= NullsCapacity)) {
                    while (remaining) {
                        auto bits = min(remaining, NullsCapacity);

                        packet.put(nulls[0..(bits + 7) >> 3]);
                        bitsOut += bits;
                        nulls[] = 0;

                        remaining = (indexArg - bitsOut);
                        if (!remaining || (!finishing && (remaining < NullsCapacity)))
                            break;
                    }
                }
            }
            packet.put!ubyte(1); // new-params-bound flag: parameter types follow

            if (indexArg != stmt.params)
                throw new MySQLErrorException(format("Wrong number of parameters for query. Got %d but expected %d.", indexArg, stmt.params));

            foreach (arg; args[0..argCount])
                putValueType(packet, arg);

            // NULL parameters are represented only in the bitmap.
            foreach (arg; args[0..argCount]) {
                static if (!is(typeof(arg) == typeof(null))) {
                    putValue(packet, arg);
                }
            }
        }

        packet.finalize(seq_);
        ++seq_;

        socket_.write(packet.get());

        auto answer = retrieve();
        if (isStatus(answer)) {
            eatStatus(answer);
        } else {
            static if (!shouldDiscard) {
                resultSet(answer, stmt.id, Commands.COM_STMT_EXECUTE, args[args.length - 1]);
            } else {
                discardAll(answer, Commands.COM_STMT_EXECUTE);
            }
        }
    }

    /// Closes a prepared statement (COM_STMT_CLOSE); no reply is read.
    void close(PreparedStatement stmt) {
        uint[1] data = [ stmt.id ];
        send(Commands.COM_STMT_CLOSE, data);
    }

    /// Callback invoked after every OK/EOF/ERR packet. `message` refers to an
    /// internal buffer that is overwritten by the next status; copy it to keep it.
    alias OnStatusCallback = void delegate(ConnectionStatus status, const(char)[] message);
    @property void onStatus(OnStatusCallback callback) {
        onStatus_ = callback;
    }

    @property OnStatusCallback onStatus() const {
        return onStatus_;
    }

    @property ulong insertID() const {
        return status_.insertID;
    }

    // No size_t cast: it truncated these 64-bit counters on 32-bit targets.
    @property ulong affected() const {
        return status_.affected;
    }

    @property ulong matched() const {
        return status_.matched;
    }

    @property ulong changed() const {
        return status_.changed;
    }

    @property size_t warnings() const {
        return status_.warnings;
    }

    @property size_t error() const {
        return status_.error;
    }

    /// Message of the last status packet; the buffer is reused by later commands.
    @property const(char)[] status() const {
        return info_;
    }

    @property bool connected() const {
        return socket_.connected;
    }

    void disconnect() {
        socket_.close();
    }

private:
    /// Runs `sql` over the text protocol (COM_QUERY), interpolating `args` client-side.
    void query(Args...)(const(char)[] sql, Args args) {
        scope(failure) disconnect();

        static if (args.length == 0) {
            enum shouldDiscard = true;
        } else {
            enum shouldDiscard = !isCallable!(args[args.length - 1]);
        }

        enum argCount = shouldDiscard ? args.length : (args.length - 1);

        static if (argCount || (Options & ConnectionOptions.TextProtocolCheckNoArgs)) {
            send(Commands.COM_QUERY, prepareSQL(sql, args[0..argCount]));
        } else {
            send(Commands.COM_QUERY, sql);
        }

        auto answer = retrieve();
        if (isStatus(answer)) {
            eatStatus(answer);
        } else {
            static if (!shouldDiscard) {
                resultSetText(answer, Commands.COM_QUERY, args[args.length - 1]);
            } else {
                discardAll(answer, Commands.COM_QUERY);
            }
        }
    }

    /// Opens the socket and performs the handshake with `settings_`.
    void connect() {
        socket_.connect(settings_.host, settings_.port);

        seq_ = 0;
        eatHandshake(retrieve());
    }

    void send(T)(Commands cmd, T[] data) {
        send(cmd, cast(ubyte*)data.ptr, data.length * T.sizeof);
    }

    /// Sends command `cmd` followed by `length` bytes of `data`, reconnecting if needed.
    void send(Commands cmd, ubyte* data = null, size_t length = 0) {
        ensureConnected();

        seq_ = 0;
        auto header = OutputPacket(&out_);
        header.put!ubyte(cmd);
        header.finalize(seq_, length);
        ++seq_;

        socket_.write(header.get());
        if (length)
            socket_.write(data[0..length]);
    }

    void ensureConnected() {
        if(!socket_.connected)
            connect();
    }

    /// True when `packet` is an OK or ERR packet.
    bool isStatus(InputPacket packet) {
        auto id = packet.peek!ubyte;
        switch (id) {
        case StatusPackets.ERR_Packet:
        case StatusPackets.OK_Packet:
            return true;
        default:
            return false;
        }
    }

    /// Consumes `packet` as a status packet if it is OK or ERR (ERR throws).
    void check(InputPacket packet, bool smallError = false) {
        auto id = packet.peek!ubyte;
        switch (id) {
        case StatusPackets.ERR_Packet:
        case StatusPackets.OK_Packet:
            eatStatus(packet, smallError);
            break;
        default:
            break;
        }
    }

    /// Throws via eatStatus when `packet` is an ERR packet; otherwise does nothing.
    void throwIfError(InputPacket packet) {
        if (packet.peek!ubyte == StatusPackets.ERR_Packet)
            eatStatus(packet);
    }

    /**
     * Reads one packet: a 4-byte header (3-byte little-endian length +
     * sequence id) followed by the payload, into the reusable `in_` buffer.
     * Throws: `MySQLConnectionException` on an out-of-sequence packet.
     */
    InputPacket retrieve() {
        scope(failure) disconnect();

        ubyte[4] header;
        socket_.read(header);

        auto len = header[0] | (header[1] << 8) | (header[2] << 16);
        auto seq = header[3];

        if (seq != seq_)
            throw new MySQLConnectionException("Out of order packet received");

        ++seq_;

        in_.length = len;
        socket_.read(in_);

        if (in_.length != len)
            throw new MySQLConnectionException("Wrong number of bytes read");

        return InputPacket(&in_);
    }

    /**
     * Parses the server handshake, computes the password token and sends the
     * handshake response, then consumes the server's status reply.
     */
    void eatHandshake(InputPacket packet) {
        scope(failure) disconnect();

        check(packet, true);

        server_.protocol = packet.eat!ubyte;
        server_.versionString = packet.eat!(const(char)[])(packet.countUntil(0, true));
        packet.skip(1);

        server_.connection = packet.eat!uint;

        // The auth seed arrives in two parts: 8 bytes here, the rest further down.
        const auto authLengthStart = 8;
        size_t authLength = authLengthStart;

        ubyte[256] auth;
        auth[0..authLength] = packet.eat!(ubyte[])(authLength);

        packet.expect!ubyte(0);

        server_.caps = packet.eat!ushort;

        if (!packet.empty) {
            server_.charSet = packet.eat!ubyte;
            server_.status = packet.eat!ushort;
            server_.caps |= packet.eat!ushort << 16;
            server_.caps |= CapabilityFlags.CLIENT_LONG_PASSWORD;

            if ((server_.caps & CapabilityFlags.CLIENT_PROTOCOL_41) == 0)
                throw new MySQLProtocolException("Server doesn't support protocol v4.1");

            if (server_.caps & CapabilityFlags.CLIENT_SECURE_CONNECTION) {
                packet.skip(1);
            } else {
                packet.expect!ubyte(0);
            }

            packet.skip(10);

            // Second part of the seed is zero-terminated.
            authLength += packet.countUntil(0, true);
            if (authLength > auth.length)
                throw new MySQLConnectionException("Bad packet format");

            auth[authLengthStart..authLength] = packet.eat!(ubyte[])(authLength - authLengthStart);

            packet.expect!ubyte(0);
        }

        // token = SHA1(pwd) XOR SHA1(seed ~ SHA1(SHA1(pwd)))  (mysql_native_password scheme)
        ubyte[20] token;
        {
            import std.digest.sha;

            auto pass = sha1Of(cast(const(ubyte)[])settings_.pwd);
            token = sha1Of(pass);

            SHA1 sha1;
            sha1.start();
            sha1.put(auth[0..authLength]);
            sha1.put(token);
            token = sha1.finish();

            foreach (i; 0..20)
                token[i] = token[i] ^ pass[i];
        }

        caps_ = cast(CapabilityFlags)(settings_.caps & server_.caps);

        auto reply = OutputPacket(&out_);
        reply.reserve(64 + settings_.user.length + settings_.pwd.length + settings_.db.length);

        reply.put!uint(caps_);
        reply.put!uint(1);     // max packet size field - NOTE(review): 1 looks odd, confirm intent
        reply.put!ubyte(45);   // character set id - presumably utf8mb4_general_ci, confirm
        reply.fill(0, 23);     // reserved filler

        reply.put(settings_.user);
        reply.put!ubyte(0);

        if (settings_.pwd.length) {
            if (caps_ & CapabilityFlags.CLIENT_SECURE_CONNECTION) {
                reply.put!ubyte(token.length);
                reply.put(token);
            } else {
                reply.put(token);
                reply.put!ubyte(0);
            }
        } else {
            reply.put!ubyte(0);
        }

        if (settings_.db.length && (caps_ & CapabilityFlags.CLIENT_CONNECT_WITH_DB)) {
            reply.put(settings_.db);

            schema_.length = settings_.db.length;
            schema_[] = settings_.db[];
        }

        reply.put!ubyte(0);

        reply.finalize(seq_);
        ++seq_;

        socket_.write(reply.get());

        eatStatus(retrieve());
    }

    /**
     * Consumes an OK, EOF or ERR packet into `status_`/`info_` and notifies
     * `onStatus_`. `smallError` means the ERR packet has no SQL-state block
     * (as in the handshake).
     *
     * Throws: `MySQLDuplicateEntryException` / `MySQLErrorException` for ERR
     * packets, `MySQLProtocolException` for any other packet type.
     */
    void eatStatus(InputPacket packet, bool smallError = false) {
        auto id = packet.eat!ubyte;

        switch (id) {
        case StatusPackets.OK_Packet:
            status_.matched = 0;
            status_.changed = 0;
            status_.affected = packet.eatLenEnc();
            status_.insertID = packet.eatLenEnc();
            status_.flags = packet.eat!ushort;
            status_.warnings = packet.eat!ushort;
            status_.error = 0;
            info([]); // don't carry over the previous command's message

            if (!packet.empty && (caps_ & CapabilityFlags.CLIENT_SESSION_TRACK)) {
                info(packet.eat!(const(char)[])(cast(size_t)packet.eatLenEnc()));
                packet.skip(1);

                if (status_.flags & StatusFlags.SERVER_SESSION_STATE_CHANGED) {
                    packet.skip(cast(size_t)packet.eatLenEnc());
                    packet.skip(1);
                }
            }

            if (!packet.empty) {
                auto len = cast(size_t)packet.eatLenEnc();
                info(packet.eat!(const(char)[])(min(len, packet.remaining)));
            }

            // Parse "matched: N ... changed: M" from whichever branch set the
            // message (previously skipped when session tracking consumed it).
            if (info_.length) {
                auto matches = matchFirst(info_, ctRegex!(`\smatched:\s*(\d+)\s+changed:\s*(\d+)`, `i`));
                if (!matches.empty) {
                    status_.matched = matches[1].to!ulong;
                    status_.changed = matches[2].to!ulong;
                }
            }

            if (onStatus_)
                onStatus_(status_, info_);

            break;
        case StatusPackets.EOF_Packet:
            status_.affected = 0;
            status_.changed = 0;
            status_.matched = 0;
            status_.error = 0;
            status_.warnings = packet.eat!ushort;
            status_.flags = packet.eat!ushort;
            info([]);

            if (onStatus_)
                onStatus_(status_, info_);

            break;
        case StatusPackets.ERR_Packet:
            status_.affected = 0;
            status_.changed = 0;
            status_.matched = 0;
            status_.flags = 0;
            status_.warnings = 0;
            status_.error = packet.eat!ushort;
            if (!smallError)
                packet.skip(6); // SQL-state marker and state, per the protocol
            info(packet.eat!(const(char)[])(packet.remaining));

            if (onStatus_)
                onStatus_(status_, info_);

            // idup: info_ is overwritten in place by later statuses, which would
            // silently change the message of an exception the caller still holds.
            switch(status_.error) {
            case ErrorCodes.ER_DUP_ENTRY_WITH_KEY_NAME:
            case ErrorCodes.ER_DUP_ENTRY:
                throw new MySQLDuplicateEntryException(info_.idup);
            default:
                throw new MySQLErrorException(info_.idup);
            }
        default:
            throw new MySQLProtocolException("Unexpected packet format");
        }
    }

    /// Copies `value` into the reusable `info_` buffer.
    void info(const(char)[] value) {
        info_.length = value.length;
        info_[0..$] = value;
    }

    /// Reads and discards one column definition packet.
    void skipColumnDef(InputPacket packet, Commands cmd) {
        packet.skip(cast(size_t)packet.eatLenEnc()); // catalog
        packet.skip(cast(size_t)packet.eatLenEnc()); // schema
        packet.skip(cast(size_t)packet.eatLenEnc()); // table
        packet.skip(cast(size_t)packet.eatLenEnc()); // original_table
        packet.skip(cast(size_t)packet.eatLenEnc()); // name
        packet.skip(cast(size_t)packet.eatLenEnc()); // original_name
        packet.skipLenEnc(); // next_length
        packet.skip(10); // 2 + 4 + 1 + 2 + 1 // charset, length, type, flags, decimals
        packet.expect!ushort(0);

        if (cmd == Commands.COM_FIELD_LIST)
            packet.skip(cast(size_t)packet.eatLenEnc());// default values
    }

    /// Reads one column definition into `def`; the name is stored in `columns_`.
    void columnDef(InputPacket packet, Commands cmd, ref MySQLColumn def) {
        packet.skip(cast(size_t)packet.eatLenEnc()); // catalog
        packet.skip(cast(size_t)packet.eatLenEnc()); // schema
        packet.skip(cast(size_t)packet.eatLenEnc()); // table
        packet.skip(cast(size_t)packet.eatLenEnc()); // original_table
        auto len = cast(size_t)packet.eatLenEnc();
        columns_ ~= packet.eat!(const(char)[])(len);
        def.name = columns_[$-len..$];
        packet.skip(cast(size_t)packet.eatLenEnc()); // original_name
        packet.skipLenEnc(); // next_length
        packet.skip(2); // charset
        def.length = packet.eat!uint;
        def.type = cast(ColumnTypes)packet.eat!ubyte;
        def.flags = packet.eat!ushort;
        def.decimals = packet.eat!ubyte;

        packet.expect!ushort(0);

        if (cmd == Commands.COM_FIELD_LIST)
            packet.skip(cast(size_t)packet.eatLenEnc());// default values
    }

    /// Reads `count` column definition packets into `defs`.
    void columnDefs(size_t count, Commands cmd, ref MySQLColumn[] defs) {
        defs.length = count;
        foreach (i; 0..count)
            columnDef(retrieve(), cmd, defs[i]);
    }

    // callHandler overloads adapt the four supported row-handler shapes.
    // Each returns false when a bool-returning handler asks to stop.

    bool callHandler(RowHandler)(RowHandler handler, size_t i, MySQLHeader header, MySQLRow row) if ((ParameterTypeTuple!(RowHandler).length == 1) && is(ParameterTypeTuple!(RowHandler)[0] == MySQLRow)) {
        static if (is(ReturnType!(RowHandler) == void)) {
            handler(row);
            return true;
        } else {
            return handler(row); // return type must be bool
        }
    }

    bool callHandler(RowHandler)(RowHandler handler, size_t i, MySQLHeader header, MySQLRow row) if ((ParameterTypeTuple!(RowHandler).length == 2) && isNumeric!(ParameterTypeTuple!(RowHandler)[0]) && is(ParameterTypeTuple!(RowHandler)[1] == MySQLRow)) {
        static if (is(ReturnType!(RowHandler) == void)) {
            handler(cast(ParameterTypeTuple!(RowHandler)[0])i, row);
            return true;
        } else {
            return handler(cast(ParameterTypeTuple!(RowHandler)[0])i, row); // return type must be bool
        }
    }

    bool callHandler(RowHandler)(RowHandler handler, size_t i, MySQLHeader header, MySQLRow row) if ((ParameterTypeTuple!(RowHandler).length == 2) && is(ParameterTypeTuple!(RowHandler)[0] == MySQLHeader) && is(ParameterTypeTuple!(RowHandler)[1] == MySQLRow)) {
        static if (is(ReturnType!(RowHandler) == void)) {
            handler(header, row);
            return true;
        } else {
            return handler(header, row); // return type must be bool
        }
    }

    bool callHandler(RowHandler)(RowHandler handler, size_t i, MySQLHeader header, MySQLRow row) if ((ParameterTypeTuple!(RowHandler).length == 3) && isNumeric!(ParameterTypeTuple!(RowHandler)[0]) && is(ParameterTypeTuple!(RowHandler)[1] == MySQLHeader) && is(ParameterTypeTuple!(RowHandler)[2] == MySQLRow)) {
        // Cast like the (index, row) overload so narrower index types (e.g. int) compile.
        static if (is(ReturnType!(RowHandler) == void)) {
            handler(cast(ParameterTypeTuple!(RowHandler)[0])i, header, row);
            return true;
        } else {
            return handler(cast(ParameterTypeTuple!(RowHandler)[0])i, header, row); // return type must be bool
        }
    }

    /// Decodes one binary-protocol row; its NULL bitmap is offset by 2 bits.
    void resultSetRow(InputPacket packet, Commands cmd, MySQLHeader header, ref MySQLRow row) {
        assert(row.columns.length == header.length);

        packet.expect!ubyte(0);
        auto nulls = packet.eat!(ubyte[])((header.length + 2 + 7) >> 3);
        foreach (i, ref column; header) {
            const auto index = (i + 2) >> 3; // bit offset of 2
            const auto bit = (i + 2) & 7;

            if ((nulls[index] & (1 << bit)) == 0) {
                eatValue(packet, column, row.get_(i));
            } else {
                auto signed = (column.flags & FieldFlags.UNSIGNED_FLAG) == 0;
                row.get_(i) = MySQLValue(column.name, ColumnTypes.MYSQL_TYPE_NULL, signed, null, 0);
            }
        }
        assert(packet.empty);
    }

    /**
     * Reads a binary-protocol result set, passing each row to `handler`.
     * When the server opened a cursor, rows are fetched in batches with
     * COM_STMT_FETCH. An ERR packet in place of a row throws.
     */
    void resultSet(RowHandler)(InputPacket packet, uint stmt, Commands cmd, RowHandler handler) {
        columns_.length = 0;

        auto columns = cast(size_t)packet.eatLenEnc();
        columnDefs(columns, cmd, header_);
        row_.header_(header_);

        auto status = retrieve();
        if (status.peek!ubyte == StatusPackets.ERR_Packet)
            eatStatus(status);

        size_t index = 0;
        auto statusFlags = eatEOF(status);
        if (statusFlags & StatusFlags.SERVER_STATUS_CURSOR_EXISTS) {
            uint[2] data = [ stmt, 4096 ]; // todo: make setting - rows per fetch
            while (statusFlags & (StatusFlags.SERVER_STATUS_CURSOR_EXISTS | StatusFlags.SERVER_MORE_RESULTS_EXISTS)) {
                send(Commands.COM_STMT_FETCH, data);

                auto answer = retrieve();
                if (answer.peek!ubyte == StatusPackets.ERR_Packet)
                    eatStatus(answer);

                auto row = answer.empty ? retrieve() : answer;
                while (true) {
                    if (row.peek!ubyte == StatusPackets.EOF_Packet) {
                        statusFlags = eatEOF(row);
                        break;
                    }
                    throwIfError(row);

                    resultSetRow(row, Commands.COM_STMT_FETCH, header_, row_);
                    if (!callHandler(handler, index++, header_, row_)) {
                        discardUntilEOF(retrieve());
                        statusFlags = 0;
                        break;
                    }
                    row = retrieve();
                }
            }
        } else {
            while (true) {
                auto row = retrieve();
                if (row.peek!ubyte == StatusPackets.EOF_Packet) {
                    eatEOF(row);
                    break;
                }
                throwIfError(row);

                resultSetRow(row, cmd, header_, row_);
                if (!callHandler(handler, index++, header_, row_)) {
                    discardUntilEOF(retrieve());
                    break;
                }
            }
        }
    }

    /// Decodes one text-protocol row; a leading 0xfb byte marks a NULL column.
    void resultSetRowText(InputPacket packet, Commands cmd, MySQLHeader header, ref MySQLRow row) {
        assert(row.columns.length == header.length);

        foreach(i, ref column; header) {
            if (packet.peek!ubyte != 0xfb) {
                eatValueText(packet, column, row.get_(i));
            } else {
                packet.skip(1);
                auto signed = (column.flags & FieldFlags.UNSIGNED_FLAG) == 0;
                row.get_(i) = MySQLValue(column.name, ColumnTypes.MYSQL_TYPE_NULL, signed, null, 0);
            }
        }
        assert(packet.empty);
    }

    /// Reads a text-protocol result set, passing each row to `handler`.
    void resultSetText(RowHandler)(InputPacket packet, Commands cmd, RowHandler handler) {
        columns_.length = 0;

        auto columns = cast(size_t)packet.eatLenEnc();
        columnDefs(columns, cmd, header_);
        row_.header_(header_);

        eatEOF(retrieve());

        size_t index = 0;
        while (true) {
            auto row = retrieve();
            if (row.peek!ubyte == StatusPackets.EOF_Packet) {
                eatEOF(row);
                break;
            } else if (row.peek!ubyte == StatusPackets.ERR_Packet) {
                eatStatus(row);
                break;
            }

            resultSetRowText(row, cmd, header_, row_);
            if (!callHandler(handler, index++, header_, row_)) {
                discardUntilEOF(retrieve());
                break;
            }
        }
    }

    /**
     * Reads and drops a whole result set. Rows are not read when the server
     * reports an open cursor. An ERR packet in place of a row throws instead
     * of waiting forever for an EOF that will never come.
     */
    void discardAll(InputPacket packet, Commands cmd) {
        // columnDef appends names to columns_; reset it as resultSet does so
        // repeated discarded result sets don't grow it without bound.
        columns_.length = 0;

        auto columns = cast(size_t)packet.eatLenEnc();
        columnDefs(columns, cmd, header_);

        auto statusFlags = eatEOF(retrieve());
        if ((statusFlags & StatusFlags.SERVER_STATUS_CURSOR_EXISTS) == 0) {
            while (true) {
                auto row = retrieve();
                if (row.peek!ubyte == StatusPackets.EOF_Packet) {
                    eatEOF(row);
                    break;
                }
                throwIfError(row);
            }
        }
    }

    /// Drops rows starting at `packet` up to and including the EOF packet; ERR throws.
    void discardUntilEOF(InputPacket packet) {
        while (packet.peek!ubyte != StatusPackets.EOF_Packet) {
            throwIfError(packet);
            packet = retrieve();
        }
        eatEOF(packet);
    }

    /// Consumes an EOF packet into `status_`. Returns: the server status flags.
    auto eatEOF(InputPacket packet) {
        auto id = packet.eat!ubyte;
        if (id != StatusPackets.EOF_Packet)
            throw new MySQLProtocolException("Unexpected packet format");

        status_.error = 0;
        status_.warnings = packet.eat!ushort();
        status_.flags = packet.eat!ushort();
        info([]);

        if (onStatus_)
            onStatus_(status_, info_);

        return status_.flags;
    }

    /**
     * Substitutes `args` for the `?` placeholders of `sql` (outside quotes)
     * into the reusable `sql_` buffer. Non-string arrays fill one placeholder
     * per element.
     *
     * Throws: `MySQLErrorException` when arguments and placeholders differ in number.
     */
    auto prepareSQL(Args...)(const(char)[] sql, Args args) {
        auto estimated = sql.length;
        size_t argCount;

        // Count arguments and estimate the output size for a single reservation.
        foreach(i, arg; args) {
            static if (is(typeof(arg) == typeof(null))) {
                ++argCount;
                estimated += 4;
            } else static if (is(Unqual!(typeof(arg)) == MySQLValue)) {
                ++argCount;
                final switch(arg.type) with (ColumnTypes) {
                case MYSQL_TYPE_NULL:
                    estimated += 4;
                    break;
                case MYSQL_TYPE_TINY:
                    estimated += 4;
                    break;
                case MYSQL_TYPE_YEAR:
                case MYSQL_TYPE_SHORT:
                    estimated += 6;
                    break;
                case MYSQL_TYPE_INT24:
                case MYSQL_TYPE_LONG:
                    estimated += 6;
                    break;
                case MYSQL_TYPE_LONGLONG:
                    estimated += 8;
                    break;
                case MYSQL_TYPE_FLOAT:
                    estimated += 8;
                    break;
                case MYSQL_TYPE_DOUBLE:
                    estimated += 8;
                    break;
                case MYSQL_TYPE_SET:
                case MYSQL_TYPE_ENUM:
                case MYSQL_TYPE_VARCHAR:
                case MYSQL_TYPE_VAR_STRING:
                case MYSQL_TYPE_STRING:
                case MYSQL_TYPE_JSON:
                case MYSQL_TYPE_NEWDECIMAL:
                case MYSQL_TYPE_DECIMAL:
                case MYSQL_TYPE_TINY_BLOB:
                case MYSQL_TYPE_MEDIUM_BLOB:
                case MYSQL_TYPE_LONG_BLOB:
                case MYSQL_TYPE_BLOB:
                case MYSQL_TYPE_BIT:
                case MYSQL_TYPE_GEOMETRY:
                    estimated += 2 + arg.peek!(const(char)[]).length;
                    break;
                case MYSQL_TYPE_TIME:
                case MYSQL_TYPE_TIME2:
                    estimated += 18;
                    break;
                case MYSQL_TYPE_DATE:
                case MYSQL_TYPE_NEWDATE:
                case MYSQL_TYPE_DATETIME:
                case MYSQL_TYPE_DATETIME2:
                case MYSQL_TYPE_TIMESTAMP:
                case MYSQL_TYPE_TIMESTAMP2:
                    estimated += 20;
                    break;
                }
            } else static if (isArray!(typeof(arg)) && !isSomeString!(typeof(arg))) {
                argCount += arg.length;
                estimated += arg.length * 6;
            } else static if (isSomeString!(typeof(arg)) || is(Unqual!(typeof(arg)) == MySQLRawString) || is(Unqual!(typeof(arg)) == MySQLFragment) || is(Unqual!(typeof(arg)) == MySQLBinary)) {
                ++argCount;
                estimated += 2 + arg.length;
            } else {
                ++argCount;
                estimated += 6;
            }
        }

        sql_.clear;
        sql_.reserve(max(8192, estimated));

        // Type-erased per-argument appenders, so the substitution loop below is
        // not unrolled once per argument type.
        alias AppendFunc = bool function(ref Appender!(char[]), ref const(char)[] sql, ref size_t, const(void)*) @safe pure nothrow;
        AppendFunc[Args.length] funcs;
        const(void)*[Args.length] addrs;

        foreach (i, Arg; Args) {
            funcs[i] = () @trusted { return cast(AppendFunc)&appendNextValue!(Arg); }();
            addrs[i] = (ref x)@trusted{ return cast(const void*)&x; }(args[i]);
        }

        size_t indexArg;
        foreach (i; 0..Args.length) {
            if (!funcs[i](sql_, sql, indexArg, addrs[i]))
                throw new MySQLErrorException(format("Wrong number of parameters for query. Got %d but expected %d.", argCount, indexArg));
        }

        // Any placeholder left over means too few arguments; count them for the message.
        if (copyUpToNext(sql_, sql)) {
            ++indexArg;
            while (copyUpToNext(sql_, sql))
                ++indexArg;
            throw new MySQLErrorException(format("Wrong number of parameters for query. Got %d but expected %d.", argCount, indexArg));
        }

        return sql_.data;
    }

    SocketType socket_;
    MySQLHeader header_;
    MySQLRow row_;
    char[] columns_;    // backing storage for column names of the current result set
    char[] info_;       // message of the last status packet (reused buffer)
    char[] schema_;
    ubyte[] in_;        // receive buffer reused by retrieve()
    ubyte[] out_;       // send buffer used by OutputPacket
    ubyte seq_;         // expected/next packet sequence id
    Appender!(char[]) sql_;

    OnStatusCallback onStatus_;
    CapabilityFlags caps_;
    ConnectionStatus status_;
    ConnectionSettings settings_;
    ServerInfo server_;
}

/**
 * Appends `sql` up to the next `?` that is outside a quoted section ('...',
 * "..." or `...`; a backslash inside quotes skips the next character) and
 * advances `sql` past it.
 *
 * Returns: true if a placeholder was found; false if the rest of `sql` was
 * copied without finding one (then `sql` is left empty).
 */
private auto copyUpToNext(ref Appender!(char[]) app, ref const(char)[] sql) {
    size_t offset;
    dchar quote = '\0';

    while (offset < sql.length) {
        auto ch = decode!(UseReplacementDchar.no)(sql, offset);
        switch (ch) {
        case '?':
            if (!quote) {
                app.put(sql[0..offset - 1]);
                sql = sql[offset..$];
                return true;
            } else {
                goto default;
            }
        case '\'':
        case '\"':
        case '`':
            if (quote == ch) {
                quote = '\0';
            } else if (!quote) {
                quote = ch;
            }
            goto default;
        case '\\':
            if (quote && (offset < sql.length))
                decode!(UseReplacementDchar.no)(sql, offset);
            goto default;
        default:
            break;
        }
    }
    app.put(sql[0..offset]);
    sql = sql[offset..$];
    return false;
}

/**
 * Substitutes the value at `arg` (of type `T`) for the next placeholder(s):
 * one placeholder per element for non-string arrays, one otherwise.
 * `indexArg` counts the placeholders filled so far.
 *
 * Returns: false when `sql` ran out of placeholders.
 */
private bool appendNextValue(T)(ref Appender!(char[]) app, ref const(char)[] sql, ref size_t indexArg, const(void)* arg) {
    static if (isArray!T && !isSomeString!T) {
        foreach (i, ref v; *cast(T*)arg) {
            if (copyUpToNext(app, sql)) {
                appendValue(app, v);
                ++indexArg;
            } else {
                return false;
            }
        }
    } else {
        if (copyUpToNext(app, sql)) {
            appendValue(app, *cast(T*)arg);
            ++indexArg;
        } else {
            return false;
        }
    }
    return true;
}