module mysql.connection;

import std.array;
import std.functional;
import std.string;
import std.traits;

public import mysql.exception;
import mysql.packet;
public import mysql.protocol;
public import mysql.type;


/// Capability flags a client advertises during the handshake unless told otherwise.
immutable CapabilityFlags DefaultClientCaps = CapabilityFlags.CLIENT_LONG_PASSWORD | CapabilityFlags.CLIENT_LONG_FLAG |
    CapabilityFlags.CLIENT_CONNECT_WITH_DB | CapabilityFlags.CLIENT_PROTOCOL_41 | CapabilityFlags.CLIENT_SECURE_CONNECTION;


/// Parameters used to open a connection; filled in by `Connection.connect`.
struct ConnectionSettings {
    CapabilityFlags caps = DefaultClientCaps;   /// capabilities the client asks for

    const(char)[] host;
    const(char)[] user;
    const(char)[] pwd;
    const(char)[] db;       /// sent in the handshake only when non-empty and CLIENT_CONNECT_WITH_DB is negotiated
    ushort port = 3306;
}


/// Connection state as last reported by the server in an OK, EOF or ERR packet.
struct ConnectionStatus {
    CapabilityFlags caps = cast(CapabilityFlags)0;  /// client caps & server caps, set during the handshake

    ulong affected = 0;     /// affected-rows value of the last OK packet
    ulong insertID = 0;     /// last-insert-id value of the last OK packet
    ushort flags = 0;       /// status flags of the last OK/EOF packet; cleared by an ERR packet
    ushort error = 0;       /// error code of the last ERR packet; reset to 0 by an OK packet
    ushort warnings = 0;    /// warning count of the last OK/EOF packet
}


/// Details announced by the server in its initial handshake packet.
struct ServerInfo {
    const(char)[] versionString;
    ubyte protocol;
    ubyte charSet;
    ushort status;
    uint connection;        /// connection id
    uint caps;              /// lower and upper 16 capability bits combined
}


/// Handle returned by `Connection.prepare`.
struct PreparedStatement {
package:
    uint id; // todo: investigate if it's really necessary to close statements explicitly
    uint params;    /// number of parameters the statement expects
}


/**
 * A single MySQL client connection.
 *
 * `SocketType` is used through `connect(host, port)`, `read(buffer)`,
 * `write(bytes)`, `close()` and a `connected` property (inferred from the
 * calls below). Commands transparently (re)connect when the socket is not
 * connected.
 */
struct Connection(SocketType) {
    /**
     * Connects using a `name=value;name=value` connection string.
     * Recognised names: host, user, pwd, db, port. A trailing ';' is allowed.
     *
     * Throws: MySQLException if the string is malformed or has an unknown name;
     *         std.conv.ConvException if `port` is not a valid ushort.
     */
    void connect(string connectionString) {
        connectionSettings(connectionString);
        connect();
    }

    /// Connects with explicit settings. CLIENT_LONG_PASSWORD and
    /// CLIENT_PROTOCOL_41 are always added to `caps`.
    void connect(const(char)[] host, ushort port, const(char)[] user, const(char)[] pwd, const(char)[] db, CapabilityFlags caps = DefaultClientCaps) {
        settings_.host = host;
        settings_.user = user;
        settings_.pwd = pwd;
        settings_.db = db;
        settings_.port = port;
        settings_.caps = caps | CapabilityFlags.CLIENT_LONG_PASSWORD | CapabilityFlags.CLIENT_PROTOCOL_41;

        connect();
    }

    /// Switches the default database (COM_INIT_DB).
    void use(const(char)[] db) {
        send(Commands.COM_INIT_DB, db);
        status(retrieve());
    }

    /// Sends COM_PING and processes the status reply.
    void ping() {
        send(Commands.COM_PING);
        status(retrieve());
    }

    /// Sends COM_REFRESH (no payload) and processes the status reply.
    void refresh() {
        send(Commands.COM_REFRESH);
        status(retrieve());
    }

    /// Sends COM_RESET_CONNECTION and processes the status reply.
    void reset() {
        send(Commands.COM_RESET_CONNECTION);
        status(retrieve());
    }

    /**
     * Returns the server's COM_STATISTICS text (the whole reply payload).
     * NOTE(review): the result presumably slices the receive buffer and is
     * then only valid until the next command — copy it if kept.
     */
    const(char)[] statistics() {
        send(Commands.COM_STATISTICS);

        auto answer = retrieve();
        return answer.eat!(const(char)[])(answer.remaining);
    }

    /**
     * Prepares `sql` on the server (COM_STMT_PREPARE).
     *
     * Returns: a PreparedStatement carrying the server statement id and the
     *          parameter count.
     * Throws: MySQLErrorException if the server answers with an ERR packet.
     */
    auto prepare(const(char)[] sql) {
        send(Commands.COM_STMT_PREPARE, sql);

        auto answer = retrieve();

        // A successful reply starts with 0x00 — the same byte as an OK packet
        // header — so it must not be parsed by status(). Only an ERR packet is
        // handed over (status() throws for it).
        if (answer.peek!ubyte == StatusPackets.ERR_Packet)
            status(answer);

        answer.expect!ubyte(0);

        auto id = answer.eat!uint;
        auto columns = answer.eat!ushort;
        auto params = answer.eat!ushort;
        answer.expect!ubyte(0);
        // The trailing warning count is not used.

        // Parameter and column definitions follow, each list closed by EOF;
        // they are read only to consume them.
        if (params) {
            MySQLColumn def;
            foreach (i; 0..params)
                columnDef(retrieve(), Commands.COM_STMT_PREPARE, def);

            skipEOF(retrieve());
        }

        if (columns) {
            MySQLColumn def;
            foreach (i; 0..columns)
                columnDef(retrieve(), Commands.COM_STMT_PREPARE, def);

            skipEOF(retrieve());
        }

        return PreparedStatement(id, params);
    }

    /**
     * Prepares `stmt`, executes it once with `args`, then closes it.
     * `args` follows the rules of the PreparedStatement overload.
     */
    void execute(Args...)(const(char)[] stmt, Args args) {
        auto id = prepare(stmt);

        // Release the server-side statement even if execution or the row
        // handler throws. If the connection was dropped meanwhile (retrieve()
        // disconnects on failure), closing would make send() reconnect just
        // to close an id from the previous session, so it is skipped.
        scope (exit) {
            if (socket_.connected)
                close(id);
        }

        execute(id, args);
    }

    /**
     * Executes a prepared statement (COM_STMT_EXECUTE, binary protocol).
     *
     * `args` holds one value per statement parameter; a `null` argument is
     * sent as NULL. If the last argument is callable it is the row handler
     * (see the callHandler overloads); otherwise result rows are discarded.
     *
     * Throws: MySQLErrorException if the value count differs from
     *         `stmt.params` or the server answers with an ERR packet.
     */
    void execute(Args...)(PreparedStatement stmt, Args args) {
        ensureConnected();

        seq_ = 0;
        auto packet = OutputPacket(&out_);
        packet.put!ubyte(Commands.COM_STMT_EXECUTE);
        packet.put!uint(stmt.id);
        packet.put!ubyte(Cursors.CURSOR_TYPE_READ_ONLY);
        packet.put!uint(1);     // iteration count

        static if (args.length == 0) {
            enum shouldDiscard = true;
        } else {
            enum shouldDiscard = !isCallable!(args[args.length - 1]);
        }

        // Parameter values only — a trailing row handler is not a parameter.
        enum argCount = shouldDiscard ? args.length : (args.length - 1);

        if (argCount != stmt.params)
            throw new MySQLErrorException("Wrong number of parameters for query");

        static if (argCount) {
            // NULL bitmap: one bit per parameter, sized from the parameter
            // count (not args.length, which would include the row handler).
            ubyte[(argCount + 7) >> 3] nulls;
            foreach (i, arg; args[0..argCount]) {
                static if (is(typeof(arg) == typeof(null)))
                    nulls[i >> 3] |= cast(ubyte)(1 << (i & 7));
            }

            packet.put(nulls[]);
            packet.put!ubyte(1);    // parameter types follow

            foreach (arg; args[0..argCount])
                putValueType(packet, arg);

            foreach (arg; args[0..argCount]) {
                static if (!is(typeof(arg) == typeof(null))) {
                    putValue(packet, arg);
                }
            }
        }

        packet.finalize(seq_);
        ++seq_;

        socket_.write(packet.get());

        auto answer = retrieve();
        if (isStatus(answer)) {
            status(answer);
        } else {
            static if (!shouldDiscard) {
                resultSet(answer, stmt.id, Commands.COM_STMT_EXECUTE, args[args.length - 1]);
            } else {
                discardAll(answer, Commands.COM_STMT_EXECUTE);
            }
        }
    }

    /// Sends COM_STMT_CLOSE for `stmt`; no reply is read.
    void close(PreparedStatement stmt) {
        uint[1] data = [ stmt.id ];
        send(Commands.COM_STMT_CLOSE, data);
    }

    /// Last-insert-id reported by the most recent OK packet.
    ulong insertID() {
        return status_.insertID;
    }

    /// Affected-rows count reported by the most recent OK packet.
    ulong affected() {
        return status_.affected;
    }

    /// Warning count reported by the most recent OK/EOF packet.
    size_t warnings() {
        return status_.warnings;
    }

    /// Error code of the most recent ERR packet (0 after an OK packet).
    size_t error() {
        return status_.error;
    }

    /// Info / error message text of the most recent status packet.
    const(char)[] status() {
        return info_;
    }

    /// Closes the underlying socket.
    void disconnect() {
        socket_.close();
    }

    // NOTE(review): Connection is copyable; each copy's destructor calls
    // disconnect() — confirm copies are not expected to share a socket.
    ~this() {
        disconnect();
    }

private:
    /// Opens the socket from settings_ and performs the handshake.
    void connect() {
        socket_.connect(settings_.host, settings_.port);

        seq_ = 0;
        handshake(retrieve());
    }

    /// Sends a command whose payload is the raw bytes of `data`.
    void send(T)(Commands cmd, T[] data) {
        send(cmd, cast(ubyte*)data.ptr, data.length * T.sizeof);
    }

    /// Sends a command packet, connecting first if needed.
    void send(Commands cmd, ubyte* data = null, size_t length = 0) {
        ensureConnected();

        // Every command starts a new exchange whose first packet carries
        // sequence id 0 — also right after an implicit connect(), whose
        // handshake leaves seq_ advanced.
        seq_ = 0;

        auto header = OutputPacket(&out_);
        header.put!ubyte(cmd);
        header.finalize(seq_, length);
        ++seq_;

        socket_.write(header.get());
        if (length)
            socket_.write(data[0..length]);
    }

    /// Connects if the socket is not connected.
    void ensureConnected() {
        if (!socket_.connected)
            connect();
    }

    /// True when `packet` starts with an OK or ERR header.
    bool isStatus(InputPacket packet) {
        auto id = packet.peek!ubyte;
        switch (id) {
        case StatusPackets.ERR_Packet:
        case StatusPackets.OK_Packet:
            return true;
        default:
            return false;
        }
    }

    /// Processes `packet` with status() if it is an OK or ERR packet (ERR throws).
    void check(InputPacket packet) {
        auto id = packet.peek!ubyte;
        switch (id) {
        case StatusPackets.ERR_Packet:
        case StatusPackets.OK_Packet:
            status(packet);
            break;
        default:
            break;
        }
    }

    /**
     * Reads one packet: a 4-byte header (3-byte little-endian payload length,
     * 1-byte sequence id) followed by the payload, stored in in_.
     * Disconnects on any failure.
     *
     * Throws: MySQLConnectionException on an unexpected sequence id or a
     *         short read.
     * NOTE(review): payloads split across several packets (length 0xFFFFFF)
     * are not reassembled here.
     */
    InputPacket retrieve() {
        scope(failure) disconnect();

        ubyte[4] header;
        socket_.read(header);

        auto len = header[0] | (header[1] << 8) | (header[2] << 16);
        auto seq = header[3];

        if (seq != seq_)
            throw new MySQLConnectionException("Out of order packet received");

        ++seq_;

        in_.length = len;
        socket_.read(in_);

        if (in_.length != len)
            throw new MySQLConnectionException("Wrong number of bytes read");

        return InputPacket(&in_);
    }

    /**
     * Parses the server's initial handshake packet into server_, sends the
     * handshake response (password scrambled as
     * SHA1(pwd) XOR SHA1(scramble ~ SHA1(SHA1(pwd)))) and processes the
     * server's reply with status(). Disconnects on any failure.
     *
     * Throws: MySQLProtocolException if the server lacks CLIENT_PROTOCOL_41;
     *         MySQLConnectionException on a malformed scramble;
     *         MySQLErrorException if authentication is rejected.
     */
    void handshake(InputPacket packet) {
        scope(failure) disconnect();

        server_.protocol = packet.eat!ubyte;
        // Copied: the packet buffer (in_) is reused by the next retrieve().
        server_.versionString = packet.eat!(const(char)[])(packet.countUntil(0, true)).idup;
        packet.skip(1);

        server_.connection = packet.eat!uint;

        // First 8 bytes of the scramble; more may follow further down.
        const auto authLengthStart = 8;
        size_t authLength = authLengthStart;

        ubyte[256] auth;
        auth[0..authLength] = packet.eat!(ubyte[])(authLength);

        packet.expect!ubyte(0);

        server_.caps = packet.eat!ushort;

        if (!packet.empty) {
            server_.charSet = packet.eat!ubyte;
            server_.status = packet.eat!ushort;
            server_.caps |= packet.eat!ushort << 16;
            server_.caps |= CapabilityFlags.CLIENT_LONG_PASSWORD;

            if ((server_.caps & CapabilityFlags.CLIENT_PROTOCOL_41) == 0)
                throw new MySQLProtocolException("Server doesn't support protocol v4.1");

            // One byte: skipped with CLIENT_SECURE_CONNECTION, must be 0 otherwise.
            if (server_.caps & CapabilityFlags.CLIENT_SECURE_CONNECTION) {
                packet.skip(1);
            } else {
                packet.expect!ubyte(0);
            }

            packet.skip(10);    // reserved bytes

            // Remainder of the scramble, up to its terminating 0.
            authLength += packet.countUntil(0, true);
            if (authLength > auth.length)
                throw new MySQLConnectionException("Bad packet format");

            auth[authLengthStart..authLength] = packet.eat!(ubyte[])(authLength - authLengthStart);

            packet.expect!ubyte(0);
        }

        ubyte[20] token;
        {
            import std.digest.sha;

            auto pass = sha1Of(cast(const(ubyte)[])settings_.pwd);
            token = sha1Of(pass);

            SHA1 sha1;
            sha1.start();
            sha1.put(auth[0..authLength]);
            sha1.put(token);
            token = sha1.finish();

            foreach (i; 0..20)
                token[i] = token[i] ^ pass[i];
        }

        status_.caps = cast(CapabilityFlags)(settings_.caps & server_.caps);

        auto reply = OutputPacket(&out_);
        reply.reserve(64 + settings_.user.length + settings_.pwd.length + settings_.db.length);

        reply.put!uint(status_.caps);
        reply.put!uint(1);      // max packet size field; NOTE(review): confirm 1 is intended
        reply.put!ubyte(33);    // client character-set id
        reply.fill(0, 23);      // zero filler

        reply.put(settings_.user);
        reply.put!ubyte(0);

        if (settings_.pwd.length) {
            if (status_.caps & CapabilityFlags.CLIENT_SECURE_CONNECTION) {
                reply.put!ubyte(token.length);
                reply.put(token);
            } else {
                reply.put(token);
                reply.put!ubyte(0);
            }
        } else {
            reply.put!ubyte(0);
        }

        if (settings_.db.length && (status_.caps & CapabilityFlags.CLIENT_CONNECT_WITH_DB)) {
            reply.put(settings_.db);
            reply.put!ubyte(0);
        }

        reply.finalize(seq_);
        ++seq_;

        socket_.write(reply.get());

        status(retrieve());
    }

    /**
     * Parses an OK, EOF or ERR packet into status_ and info_.
     *
     * Throws: MySQLErrorException for an ERR packet (message = server text);
     *         MySQLProtocolException for any other header byte.
     */
    void status(InputPacket packet) {
        auto id = packet.eat!ubyte;

        switch (id) {
        case StatusPackets.OK_Packet:
            status_.error = 0;
            status_.affected = packet.eatLenEnc();
            status_.insertID = packet.eatLenEnc();
            status_.flags = packet.eat!ushort;
            status_.warnings = packet.eat!ushort;

            if (status_.caps & CapabilityFlags.CLIENT_SESSION_TRACK) {
                info(packet.eat!(const(char)[])(cast(size_t)packet.eatLenEnc()));
                packet.skip(1);

                if (status_.flags & StatusFlags.SERVER_SESSION_STATE_CHANGED) {
                    packet.skip(cast(size_t)packet.eatLenEnc());
                    packet.skip(1);
                }
            } else if (!packet.empty) {
                info(packet.eat!(const(char)[])(cast(size_t)packet.eatLenEnc()));
            }
            break;
        case StatusPackets.EOF_Packet:
            status_.warnings = packet.eat!ushort;
            status_.flags = packet.eat!ushort;
            info([]);
            break;
        case StatusPackets.ERR_Packet:
            status_.flags = 0;
            status_.error = packet.eat!ushort;
            packet.skip(6);     // SQL-state marker and code
            info(packet.eat!(const(char)[])(packet.remaining));

            throw new MySQLErrorException(cast(string)info_);
        default:
            throw new MySQLProtocolException("Unexpected packet format");
        }
    }

    /// Copies `value` into info_ (so it survives reuse of the packet buffer).
    void info(const(char)[] value) {
        info_.length = value.length;
        info_[0..$] = value;
    }

    /// Parses one column-definition packet into `def` (only name, length,
    /// type, flags and decimals are kept; the other fields are skipped).
    void columnDef(InputPacket packet, Commands cmd, ref MySQLColumn def) {
        packet.skip(cast(size_t)packet.eatLenEnc());   // catalog
        packet.skip(cast(size_t)packet.eatLenEnc());   // schema
        packet.skip(cast(size_t)packet.eatLenEnc());   // table
        packet.skip(cast(size_t)packet.eatLenEnc());   // org_table
        def.name = packet.eat!(const(char)[])(cast(size_t)packet.eatLenEnc()).idup; // todo: fix allocation
        packet.skip(cast(size_t)packet.eatLenEnc());   // org_name
        packet.eatLenEnc();                             // length of the fixed-size fields
        packet.eat!ushort;                              // character set
        def.length = packet.eat!uint;
        def.type = cast(ColumnTypes)packet.eat!ubyte;
        def.flags = packet.eat!ushort;
        def.decimals = packet.eat!ubyte;

        packet.expect!ushort(0);

        if (cmd == Commands.COM_FIELD_LIST)
            packet.skip(cast(size_t)packet.eatLenEnc());   // default values
    }

    /// Reads `count` column definitions into header_ and returns it.
    auto columnDefs(size_t count, Commands cmd) {
        header_.length = count;
        foreach (i; 0..count)
            columnDef(retrieve(), cmd, header_[i]);
        return header_;
    }

    /// Decodes one binary-protocol row into `row` (0x00 header, NULL bitmap
    /// with a 2-bit offset, then the non-NULL values).
    void resultSetRow(InputPacket packet, Commands cmd, MySQLHeader header, MySQLRow row) {
        assert(row.length == header.length);

        packet.expect!ubyte(0);
        auto nulls = packet.eat!(ubyte[])((header.length + 2 + 7) >> 3);
        foreach (i, column; header) {
            const auto index = (i + 2) >> 3; // bit offset of 2
            const auto bit = (i + 2) & 7;

            if ((nulls[index] & (1 << bit)) == 0) {
                row[i] = eatValue(packet, column);
            } else {
                row[i].nullify();
            }
        }
        assert(packet.empty);
    }

    // The callHandler overloads adapt the supported row-handler shapes:
    // (row), (index, row), (header, row) and (index, header, row).
    // A void handler continues iteration; otherwise its bool result decides.

    bool callHandler(RowHandler)(RowHandler handler, size_t i, MySQLHeader header, MySQLRow row) if ((ParameterTypeTuple!(RowHandler).length == 1) && is(ParameterTypeTuple!(RowHandler)[0] == MySQLRow)) {
        static if (is(ReturnType!(RowHandler) == void)) {
            handler(row);
            return true;
        } else {
            return handler(row); // return type must be bool
        }
    }

    bool callHandler(RowHandler)(RowHandler handler, size_t i, MySQLHeader header, MySQLRow row) if ((ParameterTypeTuple!(RowHandler).length == 2) && isNumeric!(ParameterTypeTuple!(RowHandler)[0]) && is(ParameterTypeTuple!(RowHandler)[1] == MySQLRow)) {
        static if (is(ReturnType!(RowHandler) == void)) {
            handler(cast(ParameterTypeTuple!(RowHandler)[0])i, row);
            return true;
        } else {
            return handler(cast(ParameterTypeTuple!(RowHandler)[0])i, row); // return type must be bool
        }
    }

    bool callHandler(RowHandler)(RowHandler handler, size_t i, MySQLHeader header, MySQLRow row) if ((ParameterTypeTuple!(RowHandler).length == 2) && is(ParameterTypeTuple!(RowHandler)[0] == MySQLHeader) && is(ParameterTypeTuple!(RowHandler)[1] == MySQLRow)) {
        static if (is(ReturnType!(RowHandler) == void)) {
            handler(header, row);
            return true;
        } else {
            return handler(header, row); // return type must be bool
        }
    }

    bool callHandler(RowHandler)(RowHandler handler, size_t i, MySQLHeader header, MySQLRow row) if ((ParameterTypeTuple!(RowHandler).length == 3) && isNumeric!(ParameterTypeTuple!(RowHandler)[0]) && is(ParameterTypeTuple!(RowHandler)[1] == MySQLHeader) && is(ParameterTypeTuple!(RowHandler)[2] == MySQLRow)) {
        // Cast the index to the handler's numeric type, as the (index, row) overload does.
        static if (is(ReturnType!(RowHandler) == void)) {
            handler(cast(ParameterTypeTuple!(RowHandler)[0])i, header, row);
            return true;
        } else {
            return handler(cast(ParameterTypeTuple!(RowHandler)[0])i, header, row); // return type must be bool
        }
    }

    /**
     * Streams a result set to `handler`. When the server reports an open
     * cursor, rows are fetched with COM_STMT_FETCH in batches of 4096;
     * otherwise they are read directly. A handler returning false stops
     * iteration and the remaining rows are discarded.
     */
    void resultSet(RowHandler)(InputPacket packet, uint stmt, Commands cmd, RowHandler handler) {
        auto columns = cast(size_t)packet.eatLenEnc();
        auto header = columnDefs(columns, cmd);
        row_.length = columns;

        size_t index = 0;
        auto statusFlags = skipEOF(retrieve());
        if (statusFlags & StatusFlags.SERVER_STATUS_CURSOR_EXISTS) {
            uint[2] data = [ stmt, 4096 ]; // todo: make setting - rows per fetch
            while (statusFlags & (StatusFlags.SERVER_STATUS_CURSOR_EXISTS | StatusFlags.SERVER_MORE_RESULTS_EXISTS)) {
                send(Commands.COM_STMT_FETCH, data);

                auto answer = retrieve();
                if (answer.peek!ubyte == StatusPackets.ERR_Packet)
                    check(answer);

                auto row = answer.empty ? retrieve() : answer;
                while (true) {
                    if (row.peek!ubyte == StatusPackets.EOF_Packet) {
                        statusFlags = skipEOF(row);
                        break;
                    }

                    resultSetRow(row, Commands.COM_STMT_FETCH, header, row_);
                    if (!callHandler(handler, index++, header, row_)) {
                        discardUntilEOF(retrieve());
                        statusFlags = 0;
                        break;
                    }
                    row = retrieve();
                }
            }
        } else {
            auto row = retrieve();
            while (true) {
                if (row.peek!ubyte == StatusPackets.EOF_Packet) {
                    status(row);
                    break;
                }

                resultSetRow(row, cmd, header, row_);
                if (!callHandler(handler, index++, header, row_)) {
                    discardUntilEOF(retrieve());
                    break;
                }

                row = retrieve();
            }
        }
    }

    /// Consumes a result set without decoding rows. With an open server-side
    /// cursor no rows are sent, so only the definitions are read.
    void discardAll(InputPacket packet, Commands cmd) {
        auto columns = cast(size_t)packet.eatLenEnc();
        columnDefs(columns, cmd);

        auto statusFlags = skipEOF(retrieve());
        if ((statusFlags & StatusFlags.SERVER_STATUS_CURSOR_EXISTS) == 0) {
            while (true) {
                auto row = retrieve();
                if (row.peek!ubyte == StatusPackets.EOF_Packet) {
                    status(row);
                    break;
                }
            }
        }
    }

    /// Reads packets, starting with `packet`, until an EOF packet, which is
    /// then processed by status().
    void discardUntilEOF(InputPacket packet) {
        while (packet.peek!ubyte != StatusPackets.EOF_Packet)
            packet = retrieve();
        status(packet);
    }

    /// Consumes an EOF packet and returns its status flags.
    /// Throws: MySQLProtocolException if `packet` is not an EOF packet.
    auto skipEOF(InputPacket packet) {
        auto id = packet.eat!ubyte;
        if (id != StatusPackets.EOF_Packet)
            throw new MySQLProtocolException("Unexpected packet format");

        packet.skip(2);     // warning count
        return packet.eat!ushort();
    }

    /**
     * Parses `name=value;...` pairs into settings_. Names and values are
     * whitespace-stripped slices of `connectionString`; a trailing ';' is
     * accepted.
     *
     * Throws: MySQLException for empty input, a segment without '=', or an
     *         unknown name; std.conv.ConvException for a non-numeric port.
     */
    void connectionSettings(const(char)[] connectionString) {
        import std.conv;

        MySQLException badConnectionString() {
            return new MySQLException("Bad connection string: " ~ cast(string)connectionString);
        }

        auto remaining = connectionString;
        while (true) {
            auto indexValue = remaining.indexOf("=");
            if (indexValue < 0)
                throw badConnectionString();    // empty input or a segment without '='

            auto indexValueEnd = remaining.indexOf(";", indexValue);
            if (indexValueEnd < 0)
                indexValueEnd = remaining.length;

            auto name = strip(remaining[0..indexValue]);
            auto value = strip(remaining[indexValue + 1..indexValueEnd]);

            switch (name) {
            case "host":
                settings_.host = value;
                break;
            case "user":
                settings_.user = value;
                break;
            case "pwd":
                settings_.pwd = value;
                break;
            case "db":
                settings_.db = value;
                break;
            case "port":
                settings_.port = to!ushort(value);
                break;
            default:
                throw badConnectionString();
            }

            if (indexValueEnd == remaining.length)
                return;

            remaining = remaining[indexValueEnd + 1..$];
            if (strip(remaining).empty)
                return;     // trailing ';' (optionally followed by whitespace)
        }
    }

    SocketType socket_;
    MySQLHeader header_;    // column definitions of the current result set
    MySQLRow row_;          // reused row storage passed to handlers
    char[] info_;           // copy of the last status message
    ubyte[] in_;            // receive buffer, reused for every packet
    ubyte[] out_;           // send buffer
    ubyte seq_ = 0;         // expected/next packet sequence id

    ConnectionStatus status_;
    ConnectionSettings settings_;
    ServerInfo server_;
}