1 module mysql.packet;
2 
3 
4 import std.algorithm;
5 import std.traits;
6 
7 import mysql.exception;
8 
9 
/**
 * Read cursor over the bytes of a received packet.
 *
 * The constructor takes a slice of the caller's buffer; every `eat`/`skip`
 * shrinks that slice from the front. Values are produced by reinterpreting
 * the raw bytes in place, so multi-byte values come out in the host's byte
 * order and array results alias the underlying buffer rather than copying it.
 *
 * Bounds are checked with `assert`, i.e. only in builds that keep assertions.
 */
struct InputPacket {
    @disable this();

    /// Positions the cursor at the start of `*buffer` as it is at construction time.
    this(ubyte[]* buffer) {
        buffer_ = buffer;
        in_ = *buffer_;
    }

    /// Returns the next `T` without consuming it.
    T peek(T)() if (!isArray!T) {
        assert(T.sizeof <= in_.length);
        return *(cast(T*)in_.ptr);
    }

    /// Returns the next `T` and advances the cursor by `T.sizeof` bytes.
    T eat(T)() if (!isArray!T) {
        assert(T.sizeof <= in_.length);
        auto ptr = cast(T*)in_.ptr;
        in_ = in_[T.sizeof..$];
        return *ptr;
    }

    /**
     * Returns the next `count` elements as a slice of type `T` without
     * consuming them. The slice aliases the packet buffer.
     */
    T peek(T)(size_t count) if (isArray!T) {
        // Element type is derived from T; the previous code referenced an
        // undefined symbol `Type`, so this overload could never instantiate.
        alias ValueType = typeof(T.init[0]);

        assert(ValueType.sizeof * count <= in_.length);
        auto ptr = cast(ValueType*)in_.ptr;
        return ptr[0..count];
    }

    /**
     * Returns the next `count` elements as a slice of type `T` and advances
     * the cursor past them. The slice aliases the packet buffer.
     */
    T eat(T)(size_t count) if (isArray!T) {
        alias ValueType = typeof(T.init[0]);

        assert(ValueType.sizeof * count <= in_.length);
        auto ptr = cast(ValueType*)in_.ptr;
        in_ = in_[ValueType.sizeof * count..$];
        return ptr[0..count];
    }

    /**
     * Consumes one `T` and verifies that it equals `x`.
     *
     * Throws: MySQLProtocolException if the consumed value differs.
     */
    void expect(T)(T x) {
        if (x != eat!T)
            throw new MySQLProtocolException("Bad packet format");
    }

    /// Advances the cursor by `count` bytes without interpreting them.
    void skip(size_t count) {
        assert(count <= in_.length);
        in_ = in_[count..$];
    }

    /**
     * Finds the first occurrence of byte `x` in the unread data.
     *
     * Params:
     *   x      = byte value to search for
     *   expect = when true, a missing byte is treated as a protocol error
     *
     * Returns: offset of `x` relative to the cursor, or a negative value if
     *          it is absent (only reachable when `expect` is false).
     * Throws: MySQLProtocolException if `expect` is true and `x` is absent.
     */
    auto countUntil(ubyte x, bool expect) {
        // UFCS call resolves to std.algorithm's countUntil, not this member.
        auto index = in_.countUntil(x);
        // A non-negative index already guarantees in_[index] == x, so only
        // the "not found" case needs handling.
        if (expect && index < 0)
            throw new MySQLProtocolException("Bad packet format");
        return index;
    }

    /**
     * Consumes a length-encoded integer.
     *
     * The first byte is either the value itself (below 0xfb) or a marker
     * selecting the width of the value that follows: 0xfc = 2 bytes,
     * 0xfd = 3 bytes, 0xfe = 8 bytes.
     *
     * Throws: MySQLProtocolException for any other marker (0xfb, 0xff).
     */
    ulong eatLenEnc() {
        auto header = eat!ubyte;
        if (header < 0xfb)
            return header;

        switch(header) {
        case 0xfc:
            return eat!ushort;
        case 0xfd: {
            // 3-byte value: low byte first, then the upper 16 bits.
            ulong lo = eat!ubyte;
            ulong hi = eat!ushort;
            return lo | (hi << 8);
        }
        case 0xfe: {
            // 8-byte value read as two 32-bit halves, low half first.
            ulong lo = eat!uint;
            ulong hi = eat!uint;
            return lo | (hi << 32);
        }
        default:
            throw new MySQLProtocolException("Bad packet format");
        }
    }

    /// Number of unread bytes.
    auto remaining() const {
        return in_.length;
    }

    /// True once every byte has been consumed.
    bool empty() const {
        return in_.length == 0;
    }
protected:
    ubyte[]* buffer_; // buffer handed to the constructor
    ubyte[] in_;      // unread remainder of that buffer
}
101 
102 
/**
 * Builder for an outgoing packet.
 *
 * The first 4 bytes of the caller's buffer are reserved for the header, which
 * `finalize` fills in as a 24-bit length plus an 8-bit sequence number. All
 * offsets used by this type are relative to the payload, i.e. to the byte
 * right after the header. The buffer is enlarged on demand.
 *
 * Values are stored by writing their in-memory representation directly, so
 * the wire byte order is the host's byte order.
 * NOTE(review): the protocol layout presumably requires a little-endian host — confirm.
 */
struct OutputPacket {
    @disable this();

    /// Binds the packet to `*buffer`; the payload starts 4 bytes into it.
    this(ubyte[]* buffer) {
        buffer_ = buffer;
        out_ = (*buffer_).ptr + 4;
    }

    /// Appends a single value at the current end of the payload.
    void put(T)(T x) if (!isArray!T) {
        put(offset_, x);
    }

    /// Appends all elements of `x` at the current end of the payload.
    void put(T)(T x) if (isArray!T) {
        put(offset_, x);
    }

    /**
     * Writes a single value at payload offset `offset`. The payload length
     * grows if the write ends past the current end, otherwise it is unchanged
     * (this is how a slot obtained from `marker` is filled in later).
     */
    void put(T)(size_t offset, T x) if (!isArray!T) {
        grow(offset, T.sizeof);

        *(cast(T*)(out_ + offset)) = x;
        offset_ = max(offset + T.sizeof, offset_);
    }

    /// Writes all elements of `x` at payload offset `offset`; see the scalar overload.
    void put(T)(size_t offset, T x) if (isArray!T) {
        alias ValueType = Unqual!(typeof(T.init[0]));

        immutable bytes = ValueType.sizeof * x.length;
        grow(offset, bytes);

        (cast(ValueType*)(out_ + offset))[0..x.length] = x;
        offset_ = max(offset + bytes, offset_);
    }

    /**
     * Appends `x` as a length-encoded integer: one byte for values below
     * 0xfb, otherwise a marker byte (0xfc / 0xfd / 0xfe) followed by a
     * 2-, 3- or 8-byte value.
     */
    void putLenEnc(ulong x) {
        if (x < 0xfb) {
            put!ubyte(cast(ubyte)x);
        } else if (x <= ushort.max) {
            put!ubyte(0xfc);
            put!ushort(cast(ushort)x);
        } else if (x <= (uint.max >> 8)) {
            put!ubyte(0xfd);
            put!ubyte(cast(ubyte)(x));
            put!ushort(cast(ushort)(x >> 8));
        } else {
            put!ubyte(0xfe);
            put!uint(cast(uint)x);
            put!uint(cast(uint)(x >> 32));
        }
    }

    /**
     * Reserves room for one `T` at the end of the payload.
     *
     * Returns: the payload offset of the reserved slot, to be passed to
     *          `put(offset, value)` later.
     */
    size_t marker(T)() if (!isArray!T) {
        grow(offset_, T.sizeof);

        auto place = offset_;
        offset_ += T.sizeof;
        return place;
    }

    /**
     * Reserves room for `count` elements of array type `T` at the end of the
     * payload.
     *
     * Returns: the payload offset of the reserved region.
     */
    size_t marker(T)(size_t count) if (isArray!T) {
        alias ValueType = Unqual!(typeof(T.init[0]));
        // Sized from `count`; the previous code referenced a non-existent `x`.
        immutable bytes = ValueType.sizeof * count;
        grow(offset_, bytes);

        auto place = offset_;
        offset_ += bytes;
        return place;
    }

    /**
     * Writes the header for a packet whose payload is exactly what has been
     * put so far.
     *
     * Throws: MySQLConnectionException if the payload is 0xffffff bytes or more.
     */
    void finalize(ubyte seq) {
        finalize(seq, 0);
    }

    /**
     * Writes the header for a packet whose payload is the bytes put so far
     * plus `extra` further bytes that are not stored in this buffer.
     *
     * Throws: MySQLConnectionException if the total is 0xffffff bytes or more.
     */
    void finalize(ubyte seq, size_t extra) {
        if (offset_ + extra >= 0xffffff)
            throw new MySQLConnectionException("Packet size exceeds 2^24");

        // Make sure the 4 header bytes exist even if nothing was ever put.
        grow(0, 0);

        uint length = cast(uint)(offset_ + extra);
        uint header = cast(uint)((length & 0xffffff) | (seq << 24));
        *(cast(uint*)(*buffer_).ptr) = header;
    }

    /// Discards the payload written so far; the buffer itself is kept.
    void reset() {
        offset_ = 0;
    }

    /// Ensures the buffer can hold a payload of at least `size` bytes.
    void reserve(size_t size) {
        (*buffer_).length = max((*buffer_).length, 4 + size);
        out_ = (*buffer_).ptr + 4;
    }

    /// Appends `size` copies of byte `x`.
    void fill(ubyte x, size_t size) {
        grow(offset_, size);
        // Use the requested fill byte (previously always wrote 0).
        out_[offset_..offset_ + size] = x;
        offset_ += size;
    }

    /// Payload length in bytes (header excluded).
    size_t length() const {
        return offset_;
    }

    /// True if no payload has been written.
    bool empty() const {
        return offset_ == 0;
    }

    /// Header plus payload as one slice of the underlying buffer.
    const(ubyte)[] get() const {
        return (*buffer_)[0..4 + offset_];
    }
protected:
    /// Smallest size the buffer is grown to; also keeps the doubling loop from starting at 0.
    enum size_t minCapacity = 128;

    /**
     * Makes sure the buffer covers the header plus `size` bytes starting at
     * payload offset `offset`, enlarging it geometrically when it does not.
     */
    void grow(size_t offset, size_t size) {
        immutable requested = 4 + offset + size;
        if (requested <= (*buffer_).length)
            return;

        // `.capacity` is 0 for slices that cannot be extended in place; the
        // old loop doubled that 0 forever. Seed with a non-zero minimum.
        size_t newLength = max((*buffer_).capacity, minCapacity);
        while (newLength < requested)
            newLength <<= 1;

        // Resize to the doubled size (the computed value used to be ignored
        // in favour of `requested`, defeating the amortised growth).
        (*buffer_).length = newLength;
        // The array may have moved; re-derive the payload pointer.
        out_ = (*buffer_).ptr + 4;
    }

    ubyte[]* buffer_;    // caller-owned buffer: 4-byte header + payload
    ubyte* out_;         // start of the payload inside *buffer_
    size_t offset_ = 0;  // payload bytes written so far
}