1 module mysql.inserter;
2 
3 
import std.array;
import std.meta;
import std.range;
import std.string;
import std.traits;
9 
10 
11 import mysql.appender;
12 import mysql.exception;
13 import mysql.type;
14 
/// Selects which statement form `Inserter.start` generates.
enum OnDuplicate : size_t {
	Ignore = 0,    /// `insert ignore into ...`
	Error = 1,     /// plain `insert into ...`
	Replace = 2,   /// `replace into ...`
	Update = 3,    /// plain `insert into ...`; an update clause can be supplied with `Inserter.duplicateUpdate`
	UpdateAll = 4, /// `insert into ...` plus an `on duplicate key update` clause setting every column to `values(column)`
}
22 
/// Creates an `Inserter` bound to `connection` without starting a statement.
/// Call `start()` on the result before adding rows.
///
/// The inserter stores the address of `connection`, so the connection must
/// outlive it. NOTE(review): because of `auto ref`, passing an rvalue
/// connection leaves that pointer dangling once this function returns.
auto inserter(ConnectionType)(auto ref ConnectionType connection) {
	// Inserter's constructor takes a ConnectionType*, not a value, so pass the
	// address (as the other inserter() overloads do).
	return Inserter!ConnectionType(&connection);
}
26 
27 
/// Creates an `Inserter` on `connection` and starts a statement for
/// `tableName` over `columns`. `action` selects the statement form.
///
/// The inserter stores the address of `connection`, so the connection must
/// outlive the returned value.
auto inserter(ConnectionType, Args...)(auto ref ConnectionType connection, OnDuplicate action, string tableName, Args columns) {
	auto ins = Inserter!ConnectionType(&connection);
	ins.start(action, tableName, columns);
	// Inserter is non-copyable; returning the local moves it out.
	return ins;
}
33 
34 
/// Creates an `Inserter` on `connection` and starts a plain `insert into`
/// statement (`OnDuplicate.Error`) for `tableName` over `columns`.
///
/// The inserter stores the address of `connection`, so the connection must
/// outlive the returned value.
auto inserter(ConnectionType, Args...)(auto ref ConnectionType connection, string tableName, Args columns) {
	// `connection` is an lvalue here, so the callee binds it by reference and
	// takes the same address this overload would have taken.
	return inserter(connection, OnDuplicate.Error, tableName, columns);
}
40 
41 
/// True when `T` (or the base type of an enum `T`) is a string, or when `T`
/// is an array whose elements are strings. Constrains the column-name
/// arguments of `Inserter.start`.
private enum isSomeStringOrSomeStringArray(T) =
	isSomeString!(OriginalType!T) ||
	(isArray!T && isSomeString!(ElementType!T));
45 
46 
/// Batches rows into multi-row `insert`/`replace` statements and executes them
/// through `conn_.execute`.
///
/// Rows accumulate in an internal buffer. Once the buffer grows past
/// `bufferSize` characters the statement is executed. Anything still pending
/// is executed by `flush()` or when the inserter is destroyed.
/// The inserter holds a pointer to its connection, which must outlive it.
struct Inserter(ConnectionType) {
	@disable this();
	@disable this(this);

	/// Binds the inserter to `connection` (stored by pointer, not copied).
	this(ConnectionType* connection) {
		conn_ = connection;
		pending_ = 0;
		flushes_ = 0;
	}

	/// Executes any rows that are still pending.
	/// NOTE(review): an exception from `conn_.execute` propagates out of the destructor.
	~this() {
		flush();
	}

	/// Starts a plain `insert into` statement (`OnDuplicate.Error`).
	void start(Args...)(string tableName, Args fieldNames) if (Args.length && allSatisfy!(isSomeStringOrSomeStringArray, Args)) {
		start(OnDuplicate.Error, tableName, fieldNames);
	}

	/// Builds the statement prefix `<verb> tableName(`col`,...)values`.
	///
	/// Params:
	///   action     = selects `insert ignore into`, `replace into` or
	///                `insert into`. `UpdateAll` also generates an
	///                `on duplicate key update` clause for every column.
	///   tableName  = table name, written into the SQL verbatim (not quoted).
	///   fieldNames = column names; each argument is a string or an array of strings.
	///
	/// Rows buffered under a previous `start()` are flushed first, so they are
	/// never mixed into a statement with a different column list.
	/// NOTE(review): names are wrapped in backticks but not escaped; do not
	/// pass untrusted identifiers.
	void start(Args...)(OnDuplicate action, string tableName, Args fieldNames) if (Args.length && allSatisfy!(isSomeStringOrSomeStringArray, Args)) {
		flush();

		// Flatten single names and name arrays into one ordered column list.
		string[] names;
		foreach (i, Arg; Args) {
			static if (isSomeString!(OriginalType!Arg)) {
				names ~= fieldNames[i];
			} else {
				foreach (name; fieldNames[i])
					names ~= name;
			}
		}

		auto hashes = new size_t[names.length];
		foreach (i, name; names)
			hashes[i] = hashOf(name);

		// Replace (rather than append to) any previous column metadata.
		fields_ = names.length;
		fieldsNames_ = names;
		fieldsHash_ = hashes;

		if (action == OnDuplicate.UpdateAll) {
			Appender!(char[]) dupapp;
			foreach (i, name; names) {
				if (i)
					dupapp.put(',');
				dupapp.put('`');
				dupapp.put(name);
				dupapp.put("`=values(`");
				dupapp.put(name);
				dupapp.put("`)");
			}
			dupUpdate_ = dupapp.data;
		}

		Appender!(char[]) app;

		final switch (action) with (OnDuplicate) {
		case Ignore:
			app.put("insert ignore into ");
			break;
		case Replace:
			app.put("replace into ");
			break;
		case Error:
		case Update:
		case UpdateAll:
			app.put("insert into ");
			break;
		}

		app.put(tableName);
		app.put('(');
		foreach (i, name; names) {
			if (i)
				app.put(',');
			app.put('`');
			app.put(name);
			app.put('`');
		}
		app.put(")values");
		start_ = app.data;
	}

	/// Sets the text appended after ` on duplicate key update ` when the
	/// statement is flushed. Replaces any previously set clause.
	/// Returns: this inserter, for chaining.
	auto ref duplicateUpdate(string update) {
		dupUpdate_ = update;
		return this;
	}

	/// Appends every element of `param` as a row (see the struct overload of `row`).
	void rows(T)(ref const T[] param) if (!isValueType!T) {
		foreach (ref p; param)
			row(p);
	}

	/// Looks for member `member` of `param` whose name (prefixed with
	/// `parentMembers`, e.g. `"outer."`) hashes to `fieldHash`. When found, it
	/// appends the member's value to the statement and sets `fieldFound`.
	/// Members that are not value types are searched recursively.
	private auto tryAppendField(string member, string parentMembers = "", T)(ref const T param, ref size_t fieldHash, ref bool fieldFound) {
		static if (isReadableDataMember!(Unqual!T, member)) {
			alias memberType = typeof(__traits(getMember, param, member));
			static if (isValueType!(memberType)) {
				// A NameAttribute UDA overrides the member's own name.
				static if (getUDAs!(__traits(getMember, param, member), NameAttribute).length) {
					enum nameHash = hashOf(parentMembers ~ getUDAs!(__traits(getMember, param, member), NameAttribute)[0].name);
				} else {
					enum nameHash = hashOf(parentMembers ~ member);
				}
				// For top-level members of a type carrying UnCamelCaseAttribute,
				// unCamelCase(member) is accepted as well.
				if (nameHash == fieldHash || (parentMembers == "" && getUDAs!(T, UnCamelCaseAttribute).length && hashOf(member.unCamelCase) == fieldHash)) {
					appendValue(values_, __traits(getMember, param, member));
					fieldFound = true;
					return;
				}
			} else {
				foreach (subMember; __traits(allMembers, memberType)) {
					// parentMembers is "" or ends with '.', so this yields "member." or "outer.member.".
					tryAppendField!(subMember, parentMembers ~ member ~ ".")(__traits(getMember, param, member), fieldHash, fieldFound);
					if (fieldFound)
						return;
				}
			}
		}
	}

	/// Appends one row whose values are read from the members of `param`,
	/// matched to the started columns by name. Flushes once the buffer exceeds
	/// `bufferSize`.
	/// Throws: `MySQLErrorException` if `start()` was not called or a column
	/// has no matching member. On any failure all pending rows are discarded.
	void row(T)(ref const T param) if (!isValueType!T) {
		scope (failure) reset();

		if (start_.empty)
			throw new MySQLErrorException("Inserter must be initialized with a call to start()");

		if (!pending_)
			values_.put(start_);

		values_.put(pending_ ? ",(" : "(");
		++pending_;

		foreach (i, ref fieldHash; fieldsHash_) {
			bool fieldFound = false;
			foreach (member; __traits(allMembers, T)) {
				tryAppendField!member(param, fieldHash, fieldFound);
				if (fieldFound)
					break;
			}
			if (!fieldFound)
				throw new MySQLErrorException(format("field '%s' was not found in struct => '%s' members", fieldsNames_[i], typeid(Unqual!T).name));

			if (i + 1 != fieldsHash_.length)
				values_.put(',');
		}
		values_.put(')');

		// Same configurable threshold as the positional overload.
		if (values_.data.length > bufferSize_)
			flush();

		++rows_;
	}

	/// Appends one row from positional values. Array arguments (other than
	/// strings) contribute one value per element.
	/// Throws: `MySQLErrorException` if `start()` was not called or the value
	/// count differs from the started column count. On any failure all
	/// pending rows are discarded.
	void row(Values...)(Values values) if (allSatisfy!(isValueType, Values)) {
		scope (failure) reset();

		if (start_.empty)
			throw new MySQLErrorException("Inserter must be initialized with a call to start()");

		auto valueCount = values.length;

		foreach (size_t i, Value; Values) {
			static if (isArray!Value && !isSomeString!(OriginalType!Value)) {
				valueCount = (valueCount - 1) + values[i].length;
			}
		}

		if (valueCount != fields_)
			throw new MySQLErrorException(format("Wrong number of parameters for row. Got %d but expected %d.", valueCount, fields_));

		if (!pending_)
			values_.put(start_);

		values_.put(pending_ ? ",(" : "(");
		++pending_;
		foreach (size_t i, Value; Values) {
			static if (isArray!Value && !isSomeString!(OriginalType!Value)) {
				appendValues(values_, values[i]);
			} else {
				appendValue(values_, values[i]);
			}
			if (i != values.length-1)
				values_.put(',');
		}
		values_.put(')');

		if (values_.data.length > bufferSize_)
			flush();

		++rows_;
	}

	/// Total number of rows added, including rows already flushed.
	@property size_t rows() const {
		return rows_;
	}

	/// Number of rows buffered but not yet executed.
	@property size_t pending() const {
		return pending_;
	}

	/// Number of statements executed so far.
	@property size_t flushes() const {
		return flushes_;
	}

	/// Sets the buffer length (in characters) above which a row triggers a flush.
	@property void bufferSize(size_t size) {
		bufferSize_ = size;
	}

	/// ditto
	@property size_t bufferSize() const {
		return bufferSize_;
	}

	/// Discards the buffered statement and the pending row count.
	private void reset() {
		values_.clear;
		pending_ = 0;
	}

	/// Executes the buffered statement (with the duplicate-key update clause,
	/// if one is set) and clears the buffer. Does nothing when no rows are
	/// pending.
	void flush() {
		if (pending_) {
			if (dupUpdate_.length) {
				values_.put(" on duplicate key update ");
				values_.put(dupUpdate_);
			}

			// `sql` aliases the appender's buffer. reset() only rewinds the
			// appender, and nothing writes to it before execute() runs.
			auto sql = values_.data;
			reset();

			conn_.execute(sql);
			++flushes_;
		}
	}

private:
	char[] start_;             // statement prefix built by start()
	const(char)[] dupUpdate_;  // text after " on duplicate key update "; empty if none
	Appender!(char[]) values_; // statement currently being built

	ConnectionType* conn_;
	size_t pending_;           // rows in values_ not yet executed
	size_t flushes_;           // statements executed
	size_t fields_;            // column count set by start()
	size_t rows_;              // rows added in total
	string[] fieldsNames_;     // column names set by start()
	size_t[] fieldsHash_;      // hashOf() of each column name, parallel to fieldsNames_
	size_t bufferSize_ = (128 << 10);
}