1 module mysql.packet;
2 
3 
4 import std.algorithm;
5 import std.traits;
6 
7 import mysql.exception;
8 
9 
/**
 * Sequential reader over the payload of a received MySQL packet.
 *
 * `in_` is the unread remainder of `*buffer`: `eat`/`skip` advance it,
 * `peek` does not. Values are reinterpreted straight from the bytes in host
 * byte order with no alignment fix-up. NOTE(review): this matches the
 * little-endian wire format only on little-endian hosts; confirm if
 * big-endian targets matter.
 *
 * Every read is bounds-checked and throws MySQLProtocolException on a short
 * packet. An `assert` would be compiled out in release builds and let a
 * truncated or malicious packet (e.g. a bogus length-encoded string length)
 * read past the end of the buffer.
 */
struct InputPacket {
	@disable this();

	/// Starts reading at the beginning of `*buffer`. The slice is captured
	/// here, so later reassignments of `*buffer` are not seen by this reader.
	this(ubyte[]* buffer) {
		buffer_ = buffer;
		in_ = *buffer_;
	}

	/// Returns the next `T` without consuming it.
	/// Throws: MySQLProtocolException if fewer than `T.sizeof` bytes remain.
	T peek(T)() if (!isArray!T) {
		ensureAvailable(1, T.sizeof);
		return *(cast(T*)in_.ptr);
	}

	/// Consumes and returns the next `T`.
	/// Throws: MySQLProtocolException if fewer than `T.sizeof` bytes remain.
	T eat(T)() if (!isArray!T) {
		ensureAvailable(1, T.sizeof);
		auto ptr = cast(T*)in_.ptr;
		in_ = in_[T.sizeof..$];
		return *ptr;
	}

	/// Returns the next `count` elements without consuming them.
	/// The result aliases the packet buffer; it is not a copy.
	/// Throws: MySQLProtocolException if the packet is too short.
	T peek(T)(size_t count) if (isArray!T) {
		alias ValueType = typeof(T.init[0]);

		ensureAvailable(count, ValueType.sizeof);
		auto ptr = cast(ValueType*)in_.ptr;
		return ptr[0..count];
	}

	/// Consumes and returns the next `count` elements.
	/// The result aliases the packet buffer; it is not a copy.
	/// Throws: MySQLProtocolException if the packet is too short.
	T eat(T)(size_t count) if (isArray!T) {
		alias ValueType = typeof(T.init[0]);

		ensureAvailable(count, ValueType.sizeof);
		auto ptr = cast(ValueType*)in_.ptr;
		in_ = in_[ValueType.sizeof * count..$];
		return ptr[0..count];
	}

	/// Consumes a `T` and throws MySQLProtocolException unless it equals `x`.
	void expect(T)(T x) {
		if (x != eat!T)
			throw new MySQLProtocolException("Bad packet format");
	}

	/// Discards the next `count` bytes.
	/// Throws: MySQLProtocolException if fewer than `count` bytes remain.
	void skip(size_t count) {
		ensureAvailable(count, 1);
		in_ = in_[count..$];
	}

	/// Returns the index of the first byte equal to `x` in the unread data,
	/// or a negative value if there is none (std.algorithm.countUntil).
	/// Nothing is consumed. If `expect` is true, a missing `x` throws
	/// MySQLProtocolException instead of returning a negative index.
	auto countUntil(ubyte x, bool expect) {
		// UFCS resolves to std.algorithm.countUntil, not this member.
		auto index = in_.countUntil(x);
		if (expect && (index < 0))
			throw new MySQLProtocolException("Bad packet format");
		return index;
	}

	/// Skips one length-encoded integer without decoding it.
	/// Headers below 0xfb are the value itself; 0xfb carries no payload;
	/// 0xfc/0xfd/0xfe are followed by 2/3/8 bytes; 0xff is rejected.
	void skipLenEnc() {
		auto header = eat!ubyte;
		if (header < 0xfb)
			return;

		switch(header) {
			case 0xfb:
				return;
			case 0xfc:
				skip(2);
				return;
			case 0xfd:
				skip(3);
				return;
			case 0xfe:
				skip(8);
				return;
			default:
				throw new MySQLProtocolException("Bad packet format");
		}
	}

	/// Consumes and decodes one length-encoded integer.
	/// A 0xfb header (the MySQL protocol's NULL marker) decodes as 0, so
	/// callers cannot tell it apart from a literal 0 through this method.
	/// Throws: MySQLProtocolException on a 0xff header or a short packet.
	ulong eatLenEnc() {
		auto header = eat!ubyte;
		if (header < 0xfb)
			return header;

		ulong lo;
		ulong hi;

		switch(header) {
		case 0xfb:
			return 0;
		case 0xfc:
			return eat!ushort;
		case 0xfd:
			// 3-byte little-endian value: low byte, then the upper 16 bits.
			lo = eat!ubyte;
			hi = eat!ushort;
			return lo | (hi << 8);
		case 0xfe:
			lo = eat!uint;
			hi = eat!uint;
			return lo | (hi << 32);
		default:
			throw new MySQLProtocolException("Bad packet format");
		}
	}

	/// Number of unread bytes.
	auto remaining() const {
		return in_.length;
	}

	/// True once every byte has been consumed.
	bool empty() const {
		return in_.length == 0;
	}
protected:
	// Throws unless `count` elements of `size` bytes each are still unread.
	// Compared via division so a huge `count` taken from the packet cannot
	// overflow `count * size` and slip past the check.
	void ensureAvailable(size_t count, size_t size) const {
		if ((size != 0) && (count > in_.length / size))
			throw new MySQLProtocolException("Bad packet format");
	}

	ubyte[]* buffer_;
	ubyte[] in_;
}
124 
125 
/**
 * Builder for an outgoing MySQL packet, written into a caller-owned buffer.
 *
 * Layout of `*buffer`: bytes [0..4) are reserved for the packet header,
 * which `finalize` fills in, and the payload starts at byte 4. Every
 * `offset` argument is relative to the start of the payload, and `offset_`
 * is the current payload length. Scalars are stored in host byte order.
 * NOTE(review): this matches the wire format only on little-endian hosts;
 * confirm if big-endian targets matter.
 *
 * `out_` caches `(*buffer_).ptr + 4` and is refreshed whenever the buffer
 * may have been reallocated.
 */
struct OutputPacket {
	@disable this();

	/// Starts an empty packet in `*buffer`. The buffer is grown to hold at
	/// least the 4-byte header, so `finalize` and `get` stay in bounds even
	/// if no payload is ever written.
	this(ubyte[]* buffer) {
		buffer_ = buffer;
		grow(0, 0);
		out_ = (*buffer_).ptr + 4;
	}

	/// Appends a scalar at the end of the payload.
	void put(T)(T x) if (!isArray!T) {
		put(offset_, x);
	}

	/// Appends an array's elements at the end of the payload.
	void put(T)(T x) if (isArray!T) {
		put(offset_, x);
	}

	/// Writes a scalar at payload `offset`, growing the buffer as needed.
	/// The payload length never shrinks, so overwriting an earlier spot
	/// (e.g. one returned by `marker`) leaves the rest intact.
	void put(T)(size_t offset, T x) if (!isArray!T) {
		grow(offset, T.sizeof);

		*(cast(T*)(out_ + offset)) = x;
		offset_ = max(offset + T.sizeof, offset_);
	}

	/// Writes an array's elements at payload `offset`, growing as needed.
	void put(T)(size_t offset, T x) if (isArray!T) {
		alias ValueType = Unqual!(typeof(T.init[0]));

		grow(offset, ValueType.sizeof * x.length);

		(cast(ValueType*)(out_ + offset))[0..x.length] = x;
		offset_ = max(offset + (ValueType.sizeof * x.length), offset_);
	}

	/// Appends `x` as a length-encoded integer: one byte below 0xfb,
	/// otherwise a 0xfc/0xfd/0xfe header followed by 2/3/8 value bytes.
	void putLenEnc(ulong x) {
		if (x < 0xfb) {
			put!ubyte(cast(ubyte)x);
		} else if (x <= ushort.max) {
			put!ubyte(0xfc);
			put!ushort(cast(ushort)x);
		} else if (x <= (uint.max >> 8)) {
			// 3-byte value: low byte, then the upper 16 bits.
			put!ubyte(0xfd);
			put!ubyte(cast(ubyte)(x));
			put!ushort(cast(ushort)(x >> 8));
		} else {
			put!ubyte(0xfe);
			put!uint(cast(uint)x);
			put!uint(cast(uint)(x >> 32));
		}
	}

	/// Reserves room for one `T` at the end of the payload and returns its
	/// offset, so the value can be written later with `put(offset, value)`.
	size_t marker(T)() if (!isArray!T) {
		grow(offset_, T.sizeof);

		auto place = offset_;
		offset_ += T.sizeof;
		return place;
	}

	/// Reserves room for `count` elements of `T` and returns their offset.
	size_t marker(T)(size_t count) if (isArray!T) {
		alias ValueType = Unqual!(typeof(T.init[0]));
		grow(offset_, ValueType.sizeof * count);

		auto place = offset_;
		offset_ += (ValueType.sizeof * count);
		return place;
	}

	/// Writes the header for the current payload with sequence id `seq`.
	/// Throws: MySQLConnectionException if the payload is 0xffffff bytes or more.
	void finalize(ubyte seq) {
		finalize(seq, 0);
	}

	/// Writes the header as if the payload were `extra` bytes longer than
	/// what has been put, e.g. for data that will be sent separately.
	/// The header is a 24-bit length in the low bits with `seq` in the top byte.
	/// Throws: MySQLConnectionException if the total is 0xffffff bytes or more.
	void finalize(ubyte seq, size_t extra) {
		if (offset_ + extra >= 0xffffff)
			throw new MySQLConnectionException("Packet size exceeds 2^24");
		uint length = cast(uint)(offset_ + extra);
		// Shift an unsigned value so a seq >= 0x80 never goes through a
		// negative int.
		uint header = (length & 0xffffff) | (cast(uint)seq << 24);
		*(cast(uint*)(*buffer_).ptr) = header;
	}

	/// Empties the payload. The buffer keeps its size and contents.
	void reset() {
		offset_ = 0;
	}

	/// Ensures the buffer can hold a `size`-byte payload plus the header.
	/// The buffer is never shrunk.
	void reserve(size_t size) {
		(*buffer_).length = max((*buffer_).length, 4 + size);
		out_ = (*buffer_).ptr + 4;
	}

	/// Appends `size` copies of the byte `x`.
	void fill(ubyte x, size_t size) {
		grow(offset_, size);
		out_[offset_..offset_ + size] = x;
		offset_ += size;
	}

	/// Current payload length in bytes (header excluded).
	size_t length() const {
		return offset_;
	}

	/// True if no payload bytes have been written.
	bool empty() const {
		return offset_ == 0;
	}

	/// The header plus payload, as a slice of the underlying buffer.
	const(ubyte)[] get() const {
		return (*buffer_)[0..4 + offset_];
	}
protected:
	// Makes sure `size` bytes at payload `offset` are addressable, extending
	// `*buffer_` if needed. Capacity is reserved geometrically so a packet
	// built from many small puts does not reallocate on every call.
	void grow(size_t offset, size_t size) {
		auto requested = 4 + offset + size;
		if (requested <= (*buffer_).length)
			return;

		// Start from at least 1 so the doubling terminates even when the
		// buffer has no capacity yet (e.g. an empty/null array).
		size_t newCapacity = (*buffer_).capacity;
		if (newCapacity == 0)
			newCapacity = 1;
		while (newCapacity < requested)
			newCapacity <<= 1;

		// UFCS resolves to druntime's array reserve, not the member above.
		(*buffer_).reserve(newCapacity);
		(*buffer_).length = requested;
		// The array may have moved, so re-derive the payload pointer.
		out_ = (*buffer_).ptr + 4;
	}

	ubyte[]* buffer_;
	ubyte* out_;
	size_t offset_ = 0;
}