1 module mysql.packet;
2 
3 
4 import std.algorithm;
5 import std.traits;
6 
7 import mysql.exception;
8 
9 
/**
 * Read cursor over the bytes of a received packet.
 *
 * Reads consume (or, for `peek`, inspect without consuming) bytes from the
 * front of the unread region `in_`, which starts out as the whole of
 * `*buffer`. Scalars are reinterpreted straight from the buffer memory with
 * no byte swapping and no alignment fix-up, so integers are decoded in host
 * byte order.
 * NOTE(review): the MySQL wire format is little-endian, so this presumably
 * relies on a little-endian host — confirm before porting.
 *
 * Every read is bounds-checked and throws `MySQLProtocolException` on a
 * truncated packet. The checks are not `assert`s, because those are compiled
 * out in release builds and a short packet would then cause an over-read.
 */
struct InputPacket {
	@disable this();

	/// Starts reading at the first byte of `*buffer`. Array reads return
	/// slices that alias `*buffer`; nothing is copied.
	this(ubyte[]* buffer) {
		buffer_ = buffer;
		in_ = *buffer_;
	}

	/// Returns the next `T` without consuming it.
	/// Throws: `MySQLProtocolException` if fewer than `T.sizeof` bytes remain.
	T peek(T)() if (!isArray!T) {
		ensureAvailable(T.sizeof);
		return *(cast(T*)in_.ptr);
	}

	/// Consumes and returns the next `T`.
	/// Throws: `MySQLProtocolException` if fewer than `T.sizeof` bytes remain.
	T eat(T)() if (!isArray!T) {
		ensureAvailable(T.sizeof);
		auto ptr = cast(T*)in_.ptr;
		in_ = in_[T.sizeof..$];
		return *ptr;
	}

	/// Returns the next `count` elements of array type `T` without consuming
	/// them. The result is a slice into the packet buffer, not a copy.
	/// Throws: `MySQLProtocolException` if not enough bytes remain.
	T peek(T)(size_t count) if (isArray!T) {
		alias ValueType = typeof(T.init[0]);

		ensureAvailable(count, ValueType.sizeof);
		auto ptr = cast(ValueType*)in_.ptr;
		return ptr[0..count];
	}

	/// Consumes and returns the next `count` elements of array type `T`.
	/// The result is a slice into the packet buffer, not a copy.
	/// Throws: `MySQLProtocolException` if not enough bytes remain.
	T eat(T)(size_t count) if (isArray!T) {
		alias ValueType = typeof(T.init[0]);

		ensureAvailable(count, ValueType.sizeof);
		auto ptr = cast(ValueType*)in_.ptr;
		in_ = in_[ValueType.sizeof * count..$];
		return ptr[0..count];
	}

	/// Consumes a `T` and checks that it equals `x`.
	/// Throws: `MySQLProtocolException` if the value differs.
	void expect(T)(T x) {
		if (x != eat!T)
			throw new MySQLProtocolException("Bad packet format");
	}

	/// Discards the next `count` bytes.
	/// Throws: `MySQLProtocolException` if fewer than `count` bytes remain.
	void skip(size_t count) {
		ensureAvailable(count);
		in_ = in_[count..$];
	}

	/// Returns the index of the first byte equal to `x` in the unread data,
	/// without consuming anything, or -1 if there is no such byte.
	/// Throws: `MySQLProtocolException` if `expect` is set and `x` is absent.
	auto countUntil(ubyte x, bool expect) {
		auto index = in_.countUntil(x);
		// A non-negative result from std.algorithm.countUntil already points
		// at `x`, so only the "not found" (-1) case needs handling.
		if (expect && index < 0)
			throw new MySQLProtocolException("Bad packet format");
		return index;
	}

	/// Consumes one length-encoded integer without decoding it.
	/// A header byte below 0xfb is the whole value. 0xfb has no payload.
	/// 0xfc, 0xfd and 0xfe are followed by 2, 3 and 8 bytes respectively.
	/// Throws: `MySQLProtocolException` for header 0xff or a truncated value.
	void skipLenEnc() {
		auto header = eat!ubyte;
		if (header >= 0xfb) {
			switch(header) {
				case 0xfb:
					return;
				case 0xfc:
					skip(2);
					return;
				case 0xfd:
					skip(3);
					return;
				case 0xfe:
					skip(8);
					return;
				default:
					throw new MySQLProtocolException("Bad packet format");
			}
		}
	}

	/// Consumes and decodes one length-encoded integer.
	/// Header 0xfb (used by the MySQL protocol as a NULL marker) yields 0.
	/// Throws: `MySQLProtocolException` for header 0xff or a truncated value.
	ulong eatLenEnc() {
		auto header = eat!ubyte;
		if (header < 0xfb)
			return header;

		ulong lo;
		ulong hi;

		switch(header) {
		case 0xfb:
			return 0;
		case 0xfc:
			return eat!ushort;
		case 0xfd:
			// 3-byte value: low byte first, then the upper 16 bits.
			lo = eat!ubyte;
			hi = eat!ushort;
			return lo | (hi << 8);
		case 0xfe:
			// 8-byte value read as two 32-bit halves, low half first.
			lo = eat!uint;
			hi = eat!uint;
			return lo | (hi << 32);
		default:
			throw new MySQLProtocolException("Bad packet format");
		}
	}

	/// Number of unread bytes.
	auto remaining() const {
		return in_.length;
	}

	/// True once every byte has been consumed.
	bool empty() const {
		return in_.length == 0;
	}
protected:
	/// Throws `MySQLProtocolException` unless at least `count` elements of
	/// `size` bytes each remain unread. It divides rather than multiplies so
	/// that a huge `count` (e.g. decoded by `eatLenEnc`) cannot wrap around.
	void ensureAvailable(size_t count, size_t size = 1) const {
		if (count > in_.length / size)
			throw new MySQLProtocolException("Bad packet format");
	}

	ubyte[]* buffer_; // buffer passed to the constructor; not read again within this struct
	ubyte[] in_;      // unread remainder of the packet
}
124 
125 
/**
 * Builder for an outgoing packet stored in `*buffer`.
 *
 * The first 4 bytes of the buffer are reserved for the header that
 * `finalize` writes, and payload bytes go after them. Every offset accepted
 * or returned here is relative to the start of the payload. `offset_` is the
 * payload length: the high-water mark of everything written so far.
 *
 * Scalars are stored through pointer casts in host byte order, with no
 * alignment fix-up.
 * NOTE(review): the MySQL wire format is little-endian, so this presumably
 * relies on a little-endian host — confirm before porting.
 */
struct OutputPacket {
	@disable this();

	/// Writes into `*buffer`; the payload starts after the 4-byte header
	/// slot. The writing methods grow the buffer on demand.
	this(ubyte[]* buffer) {
		buffer_ = buffer;
		out_ = (*buffer_).ptr + 4;
	}

	/// Appends `x` at the current end of the payload.
	pragma(inline, true) void put(T)(T x) {
		put(offset_, x);
	}

	/// Writes scalar `x` at payload offset `offset`, growing the buffer as
	/// needed. The payload length becomes at least `offset + T.sizeof`.
	void put(T)(size_t offset, T x) if (!isArray!T) {
		grow(offset, T.sizeof);

		*(cast(T*)(out_ + offset)) = x;
		offset_ = max(offset + T.sizeof, offset_);
	}

	/// Copies the elements of array `x` to payload offset `offset`, growing
	/// the buffer as needed. The payload length becomes at least the end of
	/// the copied range.
	void put(T)(size_t offset, T x) if (isArray!T) {
		alias ValueType = Unqual!(typeof(T.init[0]));
		auto bytes = ValueType.sizeof * x.length;

		grow(offset, bytes);

		(cast(ValueType*)(out_ + offset))[0..x.length] = x;
		offset_ = max(offset + bytes, offset_);
	}

	/// Appends `x` as a length-encoded integer. Values below 0xfb take one
	/// byte. Larger values get a 0xfc, 0xfd or 0xfe prefix followed by 2, 3
	/// or 8 value bytes.
	void putLenEnc(ulong x) {
		if (x < 0xfb) {
			put!ubyte(cast(ubyte)x);
		} else if (x <= ushort.max) {
			put!ubyte(0xfc);
			put!ushort(cast(ushort)x);
		} else if (x <= (uint.max >> 8)) {
			// 3-byte form: low byte, then the next 16 bits.
			put!ubyte(0xfd);
			put!ubyte(cast(ubyte)(x));
			put!ushort(cast(ushort)(x >> 8));
		} else {
			// 8-byte form written as two 32-bit halves, low half first.
			put!ubyte(0xfe);
			put!uint(cast(uint)x);
			put!uint(cast(uint)(x >> 32));
		}
	}

	/// Reserves room for one `T` at the end of the payload and returns its
	/// offset, so the value can be filled in later with `put(offset, value)`.
	size_t marker(T)() if (!isArray!T) {
		grow(offset_, T.sizeof);

		auto place = offset_;
		offset_ += T.sizeof;
		return place;
	}

	/// Reserves room for `count` elements of array type `T` at the end of
	/// the payload and returns the offset of the first one.
	size_t marker(T)(size_t count) if (isArray!T) {
		alias ValueType = Unqual!(typeof(T.init[0]));
		auto bytes = ValueType.sizeof * count;

		grow(offset_, bytes);

		auto place = offset_;
		offset_ += bytes;
		return place;
	}

	/// Writes the 4-byte header. The payload length goes in the low 24 bits
	/// and `seq` in the high 8 bits.
	/// Throws: `MySQLConnectionException` if the payload length is 0xffffff
	/// or more.
	void finalize(ubyte seq) {
		finalize(seq, 0);
	}

	/// Same as `finalize(seq)`, except the length written into the header
	/// also counts `extra` bytes that this method does not write. Those bytes
	/// are presumably sent separately by the caller — confirm.
	/// Throws: `MySQLConnectionException` if payload plus `extra` is 0xffffff
	/// or more.
	void finalize(ubyte seq, size_t extra) {
		auto total = offset_ + extra;
		if (total >= 0xffffff)
			throw new MySQLConnectionException("Packet size exceeds 2^24");

		// Nothing may have been put yet, so the buffer may be shorter than
		// the 4-byte header slot; make sure it exists before writing.
		grow(0, 0);

		// Shift in uint: shifting the int-promoted ubyte would set the sign bit.
		uint header = cast(uint)total | (cast(uint)seq << 24);
		*(cast(uint*)(*buffer_).ptr) = header;
	}

	/// Sets the payload length back to zero; the buffer itself is kept.
	void reset() {
		offset_ = 0;
	}

	/// Ensures the buffer can hold the header plus a `size`-byte payload.
	/// Never shrinks the buffer.
	void reserve(size_t size) {
		(*buffer_).length = max((*buffer_).length, 4 + size);
		out_ = (*buffer_).ptr + 4;
	}

	/// Appends `size` copies of the byte `x`.
	void fill(ubyte x, size_t size) {
		grow(offset_, size);
		out_[offset_..offset_ + size] = x;
		offset_ += size;
	}

	/// Current payload length in bytes (the header is not counted).
	size_t length() const {
		return offset_;
	}

	/// True if nothing has been written since construction or `reset`.
	bool empty() const {
		return offset_ == 0;
	}

	/// Returns the 4 header bytes followed by the payload, as a slice of the
	/// buffer. Only `finalize` sets the header bytes.
	const(ubyte)[] get() const {
		return (*buffer_)[0..4 + offset_];
	}
protected:
	/// Makes sure that `size` bytes starting at payload offset `offset` fit
	/// in the buffer.
	/// When it must grow, the new length starts at the larger of 128 and
	/// the array's current capacity, then doubles until large enough.
	/// `out_` is recomputed afterwards because resizing may move the array.
	void grow(size_t offset, size_t size) {
		auto requested = 4 + offset + size;
		if (requested > (*buffer_).length) {
			auto capacity = max(128, (*buffer_).capacity);
			while (capacity < requested)
				capacity <<= 1;
			(*buffer_).length = capacity;
			out_ = (*buffer_).ptr + 4;
		}
	}
	ubyte[]* buffer_; // header + payload storage supplied by the caller
	ubyte* out_;      // start of the payload: (*buffer_).ptr + 4
	size_t offset_;   // payload length written so far
}