module mysql.packet;


import std.algorithm;
import std.traits;

import mysql.exception;


/**
 * Read cursor over a received packet payload.
 *
 * Scalars and arrays are read by reinterpreting the underlying bytes in host
 * byte order. NOTE(review): the MySQL wire format is little-endian, so this
 * presumably assumes a little-endian host — confirm.
 *
 * Bounds are checked only with `assert`, so release builds do not validate
 * that enough bytes remain.
 */
struct InputPacket {
	@disable this();

	/// Wraps the bytes currently held in `*buffer`. Reading advances a private
	/// slice; `*buffer` itself is never modified.
	this(ubyte[]* buffer) {
		buffer_ = buffer;
		in_ = *buffer_;
	}

	/// Returns the next `T` without consuming it.
	T peek(T)() if (!isArray!T) {
		assert(T.sizeof <= in_.length);
		return *(cast(T*)in_.ptr);
	}

	/// Returns the next `T` and advances past it.
	T eat(T)() if (!isArray!T) {
		assert(T.sizeof <= in_.length);
		auto ptr = cast(T*)in_.ptr;
		in_ = in_[T.sizeof..$];
		return *ptr;
	}

	/// Returns the next `count` elements without consuming them.
	/// The result aliases the packet buffer; nothing is copied.
	T peek(T)(size_t count) if (isArray!T) {
		// Element type of the requested array (previously referenced the
		// undefined name `Type`, which broke every instantiation).
		alias ValueType = typeof(T.init[0]);

		assert(ValueType.sizeof * count <= in_.length);
		auto ptr = cast(ValueType*)in_.ptr;
		return ptr[0..count];
	}

	/// Returns the next `count` elements and advances past them.
	/// The result aliases the packet buffer; nothing is copied.
	T eat(T)(size_t count) if (isArray!T) {
		alias ValueType = typeof(T.init[0]);

		assert(ValueType.sizeof * count <= in_.length);
		auto ptr = cast(ValueType*)in_.ptr;
		in_ = in_[ValueType.sizeof * count..$];
		return ptr[0..count];
	}

	/// Consumes a `T` and checks that it equals `x`.
	/// Throws: MySQLProtocolException if the value differs.
	void expect(T)(T x) {
		if (x != eat!T)
			throw new MySQLProtocolException("Bad packet format");
	}

	/// Skips `count` bytes.
	void skip(size_t count) {
		assert(count <= in_.length);
		in_ = in_[count..$];
	}

	/**
	 * Returns the index of the first byte equal to `x` in the unread data, or
	 * a negative value if there is none. Nothing is consumed.
	 *
	 * Throws: MySQLProtocolException when `expect` is true and `x` is absent.
	 */
	auto countUntil(ubyte x, bool expect) {
		// UFCS resolves to std.algorithm.countUntil (member functions are not
		// UFCS candidates).
		auto index = in_.countUntil(x);
		// A non-negative index always points at `x`, so "not found" is the
		// only failure case.
		if (expect && (index < 0))
			throw new MySQLProtocolException("Bad packet format");
		return index;
	}

	/**
	 * Skips one length-encoded integer.
	 *
	 * Headers below 0xfb are the value itself. 0xfb has no payload. 0xfc,
	 * 0xfd and 0xfe are followed by 2, 3 and 8 bytes respectively.
	 *
	 * Throws: MySQLProtocolException on header 0xff.
	 */
	void skipLenEnc() {
		auto header = eat!ubyte;
		if (header >= 0xfb) {
			switch(header) {
			case 0xfb:
				return;
			case 0xfc:
				skip(2);
				return;
			case 0xfd:
				skip(3);
				return;
			case 0xfe:
				skip(8);
				return;
			default:
				throw new MySQLProtocolException("Bad packet format");
			}
		}
	}

	/**
	 * Reads one length-encoded integer.
	 *
	 * Header 0xfb yields 0, so callers cannot tell it apart from an encoded 0.
	 * (It is presumably the protocol's NULL marker — verify at call sites.)
	 *
	 * Throws: MySQLProtocolException on header 0xff.
	 */
	ulong eatLenEnc() {
		auto header = eat!ubyte;
		if (header < 0xfb)
			return header;

		ulong lo;
		ulong hi;

		switch(header) {
		case 0xfb:
			return 0;
		case 0xfc:
			return eat!ushort;
		case 0xfd:
			// 3-byte value: low byte first, then the upper 16 bits.
			lo = eat!ubyte;
			hi = eat!ushort;
			return lo | (hi << 8);
		case 0xfe:
			lo = eat!uint;
			hi = eat!uint;
			return lo | (hi << 32);
		default:
			throw new MySQLProtocolException("Bad packet format");
		}
	}

	/// Number of unread bytes.
	auto remaining() const {
		return in_.length;
	}

	/// True when every byte has been consumed.
	bool empty() const {
		return in_.length == 0;
	}

// Structs cannot be subclassed, so `private` (module-level) access is what
// `protected` effectively meant here.
private:
	ubyte[]* buffer_;
	ubyte[] in_;
}


/**
 * Builder for an outgoing packet.
 *
 * The payload is written starting at byte 4 of `*buffer`. `finalize` then
 * fills bytes 0..4 with the 24-bit payload length and an 8-bit sequence
 * number.
 *
 * Scalars are stored in host byte order. NOTE(review): this presumably
 * assumes a little-endian host, matching the MySQL wire format — confirm.
 */
struct OutputPacket {
	@disable this();

	/// Builds into `*buffer`, which may start out empty; it is grown on demand.
	this(ubyte[]* buffer) {
		buffer_ = buffer;
		out_ = (*buffer_).ptr + 4;
	}

	/// Appends a scalar at the current end of the payload.
	void put(T)(T x) if (!isArray!T) {
		put(offset_, x);
	}

	/// Appends an array's elements at the current end of the payload.
	void put(T)(T x) if (isArray!T) {
		put(offset_, x);
	}

	/// Writes a scalar at payload offset `offset`, extending the payload if needed.
	void put(T)(size_t offset, T x) if (!isArray!T) {
		grow(offset, T.sizeof);

		*(cast(T*)(out_ + offset)) = x;
		offset_ = max(offset + T.sizeof, offset_);
	}

	/// Writes an array's elements at payload offset `offset`, extending the payload if needed.
	void put(T)(size_t offset, T x) if (isArray!T) {
		alias ValueType = Unqual!(typeof(T.init[0]));

		grow(offset, ValueType.sizeof * x.length);

		(cast(ValueType*)(out_ + offset))[0..x.length] = x;
		offset_ = max(offset + (ValueType.sizeof * x.length), offset_);
	}

	/// Appends `x` as a length-encoded integer, using the smallest of the
	/// 1-, 3-, 4- or 9-byte forms that fits.
	void putLenEnc(ulong x) {
		if (x < 0xfb) {
			put!ubyte(cast(ubyte)x);
		} else if (x <= ushort.max) {
			put!ubyte(0xfc);
			put!ushort(cast(ushort)x);
		} else if (x <= (uint.max >> 8)) {
			put!ubyte(0xfd);
			put!ubyte(cast(ubyte)(x));
			put!ushort(cast(ushort)(x >> 8));
		} else {
			put!ubyte(0xfe);
			put!uint(cast(uint)x);
			put!uint(cast(uint)(x >> 32));
		}
	}

	/// Reserves room for one `T` at the end of the payload and returns its
	/// offset, so the value can be written later with `put(offset, value)`.
	size_t marker(T)() if (!isArray!T) {
		grow(offset_, T.sizeof);

		auto place = offset_;
		offset_ += T.sizeof;
		return place;
	}

	/// Reserves room for `count` elements of array type `T` and returns the
	/// starting offset.
	size_t marker(T)(size_t count) if (isArray!T) {
		alias ValueType = Unqual!(typeof(T.init[0]));
		// Previously used the undefined `x.length`; the size comes from `count`.
		immutable bytes = ValueType.sizeof * count;
		grow(offset_, bytes);

		auto place = offset_;
		offset_ += bytes;
		return place;
	}

	/**
	 * Writes the 4-byte header: payload length in the low 24 bits and `seq`
	 * in the high 8 bits.
	 *
	 * Throws: MySQLConnectionException if the payload is 0xffffff bytes or more.
	 */
	void finalize(ubyte seq) {
		finalize(seq, 0);
	}

	/**
	 * Like `finalize(seq)`, but the advertised length also counts `extra`
	 * bytes that are not stored in this buffer.
	 *
	 * Throws: MySQLConnectionException if `length + extra` is 0xffffff or more.
	 */
	void finalize(ubyte seq, size_t extra) {
		if (offset_ + extra >= 0xffffff)
			throw new MySQLConnectionException("Packet size exceeds 2^24");
		// Make sure the 4 header bytes exist even if nothing was ever put into
		// an initially empty buffer (otherwise the write below could hit null).
		grow(offset_, 0);
		immutable length = cast(uint)(offset_ + extra);
		// Widen `seq` before shifting so the shift is done on an unsigned value.
		immutable header = (length & 0xffffff) | (cast(uint)seq << 24);
		*(cast(uint*)(*buffer_).ptr) = header;
	}

	/// Discards the payload; the buffer memory is kept for reuse.
	void reset() {
		offset_ = 0;
	}

	/// Ensures the buffer can hold `size` payload bytes plus the header.
	void reserve(size_t size) {
		(*buffer_).length = max((*buffer_).length, 4 + size);
		out_ = (*buffer_).ptr + 4;
	}

	/// Appends `size` copies of byte `x`.
	void fill(ubyte x, size_t size) {
		grow(offset_, size);
		// Previously always wrote 0, ignoring `x`.
		out_[offset_..offset_ + size] = x;
		offset_ += size;
	}

	/// Payload length in bytes, excluding the 4-byte header.
	size_t length() const {
		return offset_;
	}

	/// True when no payload bytes have been written.
	bool empty() const {
		return offset_ == 0;
	}

	/// Header plus payload, as a view into the buffer.
	const(ubyte)[] get() const {
		return (*buffer_)[0..4 + offset_];
	}

// Structs cannot be subclassed, so `private` (module-level) access is what
// `protected` effectively meant here.
private:
	/**
	 * Ensures the buffer holds at least `4 + offset + size` bytes.
	 *
	 * Growth is geometric (doubling from at least 128 bytes), so repeated small
	 * puts do not reallocate every time. `out_` is refreshed because the buffer
	 * may move.
	 */
	void grow(size_t offset, size_t size) {
		immutable requested = 4 + offset + size;
		if (requested > (*buffer_).length) {
			size_t capacity = (*buffer_).capacity;
			// `capacity` is 0 for null or non-appendable slices; without a
			// floor the doubling loop below would never terminate.
			if (capacity < 128)
				capacity = 128;
			while (capacity < requested)
				capacity <<= 1;
			(*buffer_).length = capacity;
			out_ = (*buffer_).ptr + 4;
		}
	}

	ubyte[]* buffer_;
	ubyte* out_;
	size_t offset_ = 0;
}