1 module database.mysql.inserter;
2 
3 import std.array;
4 import std.meta;
5 import std.range;
6 import std.string;
7 import std.traits;
8 import std.typecons;
9 import std.datetime;
10 
11 import database.mysql.connection;
12 import database.mysql.exception;
13 
/// Selects the statement form that Inserter.start() generates.
enum OnDuplicate : size_t
{
    Ignore = 0,    /// start() emits `insert ignore into ...`
    Error = 1,     /// start() emits a plain `insert into ...`
    Replace = 2,   /// start() emits `replace into ...`
    Update = 3,    /// `insert into ...`; the update clause is supplied via Inserter.duplicateUpdate()
    UpdateAll = 4, /// `insert into ...` with an update clause setting every column to values(`column`)
}
22 
/// Creates an Inserter bound to `connection`. start() must be called on the
/// result before rows are added. The Inserter stores a pointer to
/// `connection`, so the connection must outlive it.
auto inserter(ref Connection connection)
{
    Connection* target = &connection;
    return Inserter(target);
}
27 
/// Creates an Inserter bound to `connection` and initializes it by calling
/// start(action, tableName, columns).
auto inserter(Args...)(ref Connection connection, OnDuplicate action, string tableName, Args columns)
{
    auto result = Inserter(&connection);
    result.start(action, tableName, columns);
    return result; // returned by move: Inserter disables copying
}
34 
/// Creates an Inserter bound to `connection` and initializes it for a plain
/// insert; equivalent to inserter(connection, OnDuplicate.Error, tableName, columns).
auto inserter(Args...)(ref Connection connection, string tableName, Args columns)
{
    return inserter(connection, OnDuplicate.Error, tableName, columns);
}
41 
/// True when T is a string type (including an enum whose base type is a
/// string) or an array whose elements are strings.
private template isSomeStringOrSomeStringArray(T)
{
    static if (isSomeString!(OriginalType!T))
        enum isSomeStringOrSomeStringArray = true;
    else static if (isArray!T)
        enum isSomeStringOrSomeStringArray = isSomeString!(ElementType!T);
    else
        enum isSomeStringOrSomeStringArray = false;
}
46 
/**
 * Accumulates rows into one multi-row insert/replace statement and executes
 * it on a Connection.
 *
 * Call start() to choose the statement form, table and columns, then add rows
 * with row()/rows(). The accumulated SQL is executed when its length exceeds
 * bufferSize after a row is added, when flush() is called, when start() is
 * called again, or when the Inserter is destroyed.
 */
struct Inserter
{
    @disable this();
    @disable this(this);

    /// Creates an inserter executing on `connection`; the pointer is stored,
    /// so the Connection must outlive this Inserter.
    this(Connection* connection)
    {
        conn_ = connection;
        pending_ = 0;
        flushes_ = 0;
    }

    /// Executes any rows still pending.
    ~this()
    {
        flush();
    }

    /// Same as start(OnDuplicate.Error, tableName, fieldNames).
    void start(Args...)(string tableName, Args fieldNames) if (Args.length && allSatisfy!(isSomeStringOrSomeStringArray, Args))
    {
        start(OnDuplicate.Error, tableName, fieldNames);
    }

    /**
     * Prepares the statement prefix for subsequent rows.
     *
     * Params:
     *   action     = statement form (see OnDuplicate); UpdateAll also sets the
     *                on-duplicate clause to update every listed column
     *   tableName  = table name, inserted into the SQL verbatim
     *   fieldNames = column names, each given as a string or an array of strings
     */
    void start(Args...)(OnDuplicate action, string tableName, Args fieldNames) if (Args.length && allSatisfy!(isSomeStringOrSomeStringArray, Args))
    {
        // Rows queued under a previous start() use a different statement
        // prefix; execute them before that prefix is replaced.
        flush();

        // Per-statement column state must not carry over from an earlier start().
        fieldsNames_ = null;
        fieldsHash_ = null;

        // Flatten string and string-array arguments into one ordered list, so
        // empty arrays cannot produce dangling commas in the generated SQL.
        foreach (i, Arg; Args)
        {
            static if (isSomeString!(OriginalType!Arg))
            {
                fieldsNames_ ~= fieldNames[i];
            }
            else
            {
                foreach (name; fieldNames[i])
                    fieldsNames_ ~= name;
            }
        }

        fields_ = fieldsNames_.length;
        foreach (name; fieldsNames_)
            fieldsHash_ ~= hashOf(name);

        Appender!(char[]) app;

        final switch (action) with (OnDuplicate)
        {
            case Ignore:
                app.put("insert ignore into ");
                break;
            case Replace:
                app.put("replace into ");
                break;
            case UpdateAll:
                dupUpdate_ = updateAllClause();
                goto case Update;
            case Update:
            case Error:
                app.put("insert into ");
                break;
        }

        app.put(tableName);
        app.put('(');
        foreach (j, name; fieldsNames_)
        {
            if (j)
                app.put(',');
            app.put('`');
            app.put(name);
            app.put('`');
        }
        app.put(")values");
        start_ = app.data;
    }

    /// Sets the SQL appended after " on duplicate key update " when the
    /// statement is flushed. The text is used verbatim.
    auto ref duplicateUpdate(string update)
    {
        dupUpdate_ = update;
        return this;
    }

    /// Adds every element of `param` via row().
    void rows(T)(ref const T[] param) if (!isValueType!T)
    {
        foreach (ref p; param)
            row(p);
    }

    // Appends `param.member` to values_ if that member's column name hashes to
    // `fieldHash`; recurses into non-value members using a dotted name prefix.
    // Sets `fieldFound` when a value was appended.
    private auto tryAppendField(string member, string parentMembers = "", T)(ref const T param, ref size_t fieldHash, ref bool fieldFound)
    {
        static if (isReadableDataMember!(Unqual!T, member))
        {
            alias memberType = typeof(__traits(getMember, param, member));
            static if (isValueType!(memberType))
            {
                static if (getUDAs!(__traits(getMember, param, member), NameAttribute).length)
                {
                    enum nameHash = hashOf(parentMembers ~ getUDAs!(__traits(getMember, param, member), NameAttribute)[0].name);
                }
                else
                {
                    enum nameHash = hashOf(parentMembers ~ member);
                }
                // Top-level members of a type carrying UnCamelCaseAttribute may
                // also match via the hash of `member.unCamelCase`.
                if (nameHash == fieldHash || (parentMembers == "" && getUDAs!(T, UnCamelCaseAttribute).length && hashOf(member.unCamelCase) == fieldHash))
                {
                    appendValue(values_, __traits(getMember, param, member));
                    fieldFound = true;
                    return;
                }
            }
            else
            {
                foreach (subMember; __traits(allMembers, memberType))
                {
                    // parentMembers is "" at top level, giving "member." there.
                    tryAppendField!(subMember, parentMembers ~ member ~ ".")(__traits(getMember, param, member), fieldHash, fieldFound);
                    if (fieldFound)
                        return;
                }
            }
        }
    }

    /**
     * Adds one row taken from the members of `param` that match the columns
     * given to start(), in column order.
     *
     * Throws: MySQLErrorException if start() was not called or a column has
     * no matching member. On any failure all pending rows are discarded.
     */
    void row (T) (ref const T param) if (!isValueType!T)
    {
        scope (failure) reset();

        if (start_.empty)
            throw new MySQLErrorException("Inserter must be initialized with a call to start()");

        if (!pending_)
            values_.put(start_);

        values_.put(pending_ ? ",(" : "(");
        ++pending_;

        foreach (i, ref fieldHash; fieldsHash_)
        {
            bool fieldFound = false;
            foreach (member; __traits(allMembers, T))
            {
                tryAppendField!member(param, fieldHash, fieldFound);
                if (fieldFound)
                    break;
            }
            if (!fieldFound)
                throw new MySQLErrorException(format("field '%s' was not found in struct => '%s' members", fieldsNames_[i], typeid(Unqual!T).name));

            if (i + 1 != fieldsHash_.length)
                values_.put(',');
        }
        values_.put(')');

        if (values_.data.length > bufferSize_)
            flush();

        ++rows_;
    }

    /**
     * Adds one row from positional values; array (non-string) arguments
     * contribute one value per element.
     *
     * Throws: MySQLErrorException if start() was not called or the value
     * count differs from the column count. On any failure all pending rows
     * are discarded.
     */
    void row(Values...)(Values values) if(allSatisfy!(isValueType, Values))
    {
        scope(failure) reset();

        if (start_.empty)
            throw new MySQLErrorException("Inserter must be initialized with a call to start()");

        auto valueCount = values.length;

        foreach (size_t i, Value; Values)
        {
            static if (isArray!Value && !isSomeString!(OriginalType!Value))
            {
                valueCount = (valueCount - 1) + values[i].length;
            }
        }

        if (valueCount != fields_)
            throw new MySQLErrorException(format("Wrong number of parameters for row. Got %d but expected %d.", valueCount, fields_));

        if (!pending_)
            values_.put(start_);

        values_.put(pending_ ? ",(" : "(");
        ++pending_;
        foreach (size_t i, Value; Values)
        {
            static if (isArray!Value && !isSomeString!(OriginalType!Value))
            {
                ValueAppender.appendValues(values_, values[i]);
            }
            else
            {
                ValueAppender.appendValue(values_, values[i]);
            }
            if (i != values.length-1)
                values_.put(',');
        }
        values_.put(')');

        if (values_.data.length > bufferSize_)
            flush();

        ++rows_;
    }

    /// Total number of rows added through row().
    @property size_t rows() const
    {
        return rows_;
    }

    /// Number of rows queued in the statement not yet executed.
    @property size_t pending() const
    {
        return pending_;
    }

    /// Number of statements executed by flush().
    @property size_t flushes() const
    {
        return flushes_;
    }

    /// Sets the SQL length (in chars) above which row() flushes automatically.
    @property void bufferSize(size_t size)
    {
        bufferSize_ = size;
    }

    /// ditto
    @property size_t bufferSize() const
    {
        return bufferSize_;
    }

    // Discards the queued statement text and pending row count.
    private void reset()
    {
        values_.clear;
        pending_ = 0;
    }

    // Builds "`c1`=values(`c1`),`c2`=values(`c2`),..." over the current columns;
    // used as the on-duplicate clause for OnDuplicate.UpdateAll.
    private char[] updateAllClause() const
    {
        Appender!(char[]) app;
        foreach (j, name; fieldsNames_)
        {
            if (j)
                app.put(',');
            app.put('`');
            app.put(name);
            app.put("`=values(`");
            app.put(name);
            app.put("`)");
        }
        return app.data;
    }

    /// Executes the queued statement, if any rows are pending.
    void flush()
    {
        if (!pending_)
            return;

        if (dupUpdate_.length)
        {
            values_.put(" on duplicate key update ");
            values_.put(dupUpdate_);
        }

        // reset() runs before execute(): if execute throws, pending_ is already
        // 0, so a later flush (e.g. from the destructor) does not resend it.
        auto sql = values_.data;
        reset();

        conn_.execute(sql);
        ++flushes_;
    }

private:

    char[] start_;              // "<insert form> table(`cols`)values"
    const(char)[] dupUpdate_;   // on-duplicate clause body; empty = none
    Appender!(char[]) values_;  // statement text being accumulated

    Connection* conn_;
    size_t pending_;            // rows in the unexecuted statement
    size_t flushes_;            // statements executed
    size_t fields_;             // column count set by start()
    size_t rows_;               // rows added in total
    string[] fieldsNames_;      // column names in order
    size_t[] fieldsHash_;       // hashOf of each column name, same order
    size_t bufferSize_ = (128 << 10);
}