1 module database.postgresql.inserter;
2 
3 import std.array;
4 import std.meta;
5 import std.range;
6 import std.string;
7 import std.traits;
8 
9 import database.postgresql.connection;
10 import database.postgresql.exception;
11 
/**
 * Selects which statement form `Inserter.start` generates with respect to
 * duplicate keys.
 */
enum OnDuplicate : size_t
{
    Ignore    = 0, /// emits `insert ignore into ...`
    Error     = 1, /// emits a plain `insert into ...`
    Replace   = 2, /// emits `replace into ...`
    Update    = 3, /// plain `insert into ...`; any update clause comes from `Inserter.duplicateUpdate`
    UpdateAll = 4, /// plain `insert into ...` plus a generated update clause covering every column
}
20 
/// Creates an `Inserter` bound to `connection`; `start()` must be called on it before rows are added.
auto inserter(ref Connection connection)
{
    Connection* target = &connection;
    return Inserter(target);
}
25 
/**
 * Creates an `Inserter` bound to `connection` and immediately calls
 * `start(action, tableName, columns)` on it.
 */
auto inserter(Args...)(ref Connection connection, OnDuplicate action, string tableName, Args columns)
{
    auto ins = Inserter(&connection);
    ins.start(action, tableName, columns);
    return ins;
}
32 
/// Shorthand for `inserter(connection, OnDuplicate.Error, tableName, columns)`.
auto inserter(Args...)(ref Connection connection, string tableName, Args columns)
{
    return inserter(connection, OnDuplicate.Error, tableName, columns);
}
39 
/// True when `T` is a string type (including enums whose base type is a
/// string) or an array whose elements are strings.
private enum bool isSomeStringOrSomeStringArray(T) =
    isSomeString!(OriginalType!T)
    || (isArray!T && isSomeString!(ElementType!T));
44 
/**
 * Buffers rows and sends them to a `Connection` as multi-row
 * `insert ... values (...),(...)` statements.
 *
 * Call `start()` to choose the table and columns, then queue data with
 * `row()` / `rows()`. The buffered statement is executed by `flush()`, which
 * also runs automatically when the buffer grows beyond `flushThreshold` bytes
 * and from the destructor.
 *
 * NOTE(review): apart from `OnDuplicate.Error`, the `OnDuplicate` modes emit
 * MySQL-dialect SQL (`insert ignore`, `replace into`,
 * `on duplicate key update`, backtick-quoted identifiers), which PostgreSQL
 * rejects; PostgreSQL spells these `on conflict ... do nothing / do update`.
 * Confirm the intended dialect before relying on those modes.
 */
struct Inserter
{
    @disable this();
    @disable this(this);

    /// Buffered SQL size, in bytes, beyond which `row()` flushes automatically.
    size_t flushThreshold = 128 << 10;

    /**
     * Binds the inserter to `connection`. The pointer is dereferenced by
     * `flush()` (including the one run by the destructor), so the connection
     * must outlive this inserter.
     */
    this(Connection* connection)
    {
        conn_ = connection;
    }

    /// Sends any rows still pending; exceptions from `flush()` propagate.
    ~this()
    {
        flush();
    }

    /// Same as `start(OnDuplicate.Error, tableName, fieldNames)`.
    void start(Args...)(string tableName, Args fieldNames) if (Args.length && allSatisfy!(isSomeStringOrSomeStringArray, Args))
    {
        start(OnDuplicate.Error, tableName, fieldNames);
    }

    /**
     * Builds the `insert ... (columns)values` prefix used by later `row()` calls.
     *
     * Calling `start()` again first flushes the rows queued under the previous
     * prefix, then replaces the column list and drops the previous
     * duplicate-update clause.
     *
     * Params:
     *     action     = statement form to generate (see `OnDuplicate`).
     *     tableName  = target table; copied into the SQL verbatim (not escaped).
     *     fieldNames = column names as strings and/or string arrays, in the
     *                  order `row()` values must follow; copied into the SQL
     *                  unquoted, so they must come from trusted code.
     */
    void start(Args...)(OnDuplicate action, string tableName, Args fieldNames) if (Args.length && allSatisfy!(isSomeStringOrSomeStringArray, Args))
    {
        // Rows already buffered were built against the old prefix; appending
        // new rows to them would mix two column lists in one statement.
        flush();

        // A restart must not inherit the previous statement's update clause.
        // On the first start(), a clause set earlier via duplicateUpdate() is kept.
        if (!start_.empty)
            dupUpdate_ = null;

        // Flatten individual names and name arrays into one ordered list.
        string[] names;
        foreach (i, Arg; Args)
        {
            static if (isSomeString!(OriginalType!Arg))
                names ~= fieldNames[i];
            else
                foreach (name; fieldNames[i])
                    names ~= name;
        }

        // Replace (not extend) the per-statement column metadata.
        fieldsNames_ = names;
        fieldsHash_ = null;
        foreach (name; names)
            fieldsHash_ ~= postgresql.row.hashOf(name);
        fields_ = names.length;

        string verb;
        final switch (action) with (OnDuplicate)
        {
            case Ignore:
                verb = "insert ignore into ";
                break;
            case Replace:
                verb = "replace into ";
                break;
            case UpdateAll:
                dupUpdate_ = updateAllClause(names);
                goto case Update;
            case Update:
            case Error:
                verb = "insert into ";
                break;
        }

        // join() leaves no stray separators when a name array is empty.
        start_ = verb ~ tableName ~ "(" ~ names.join(",") ~ ")values";
    }

    /**
     * Sets the text appended after ` on duplicate key update ` when the batch
     * is flushed, replacing any clause currently set; an empty string omits
     * the clause.
     *
     * Returns: this inserter, for chaining.
     */
    auto ref duplicateUpdate(string update)
    {
        dupUpdate_ = update;
        return this;
    }

    /// Queues every element of `param` via `row()`.
    void rows(T)(const T[] param) if (!isValueType!T)
    {
        foreach (ref p; param)
            row(p);
    }

    /**
     * Appends `param.member` to the buffer if it is the column identified by
     * `fieldHash`, recursing into members that are not value types (their
     * names are matched as `parentMembers ~ member ~ "." ~ subMember`).
     *
     * Sets `fieldFound` to true once a value has been appended.
     */
    private void tryAppendField(string member, string parentMembers = "", T)(ref const T param, ref size_t fieldHash, ref bool fieldFound)
    {
        static if (isReadableDataMember!(Unqual!T, member))
        {
            alias memberType = typeof(__traits(getMember, param, member));
            static if (isValueType!(memberType))
            {
                // A NameAttribute UDA overrides the member's own name.
                static if (getUDAs!(__traits(getMember, param, member), NameAttribute).length)
                {
                    enum nameHash = postgresql.row.hashOf(parentMembers~getUDAs!(__traits(getMember, param, member), NameAttribute)[0].name);
                }
                else
                {
                    enum nameHash = postgresql.row.hashOf(parentMembers~member);
                }
                // Top-level members of UnCamelCase-annotated types also match their un-camel-cased name.
                if (nameHash == fieldHash || (parentMembers == "" && getUDAs!(T, UnCamelCaseAttribute).length && postgresql.row.hashOf(member.unCamelCase) == fieldHash))
                {
                    // NOTE(review): row(Values...) calls ValueAppender.appendValue;
                    // confirm this unqualified appendValue resolves to the same helper.
                    appendValue(values_, __traits(getMember, param, member));
                    fieldFound = true;
                    return;
                }
            }
            else
            {
                foreach (subMember; __traits(allMembers, memberType))
                {
                    static if (parentMembers == "")
                    {
                        tryAppendField!(subMember, member~".")(__traits(getMember, param, member), fieldHash, fieldFound);
                    }
                    else
                    {
                        tryAppendField!(subMember, parentMembers~member~".")(__traits(getMember, param, member), fieldHash, fieldFound);
                    }

                    if (fieldFound)
                        return;
                }
            }
        }
    }

    /**
     * Queues one row whose values are read from the data members of `param`,
     * in the column order given to `start()`.
     *
     * If this row fails, only this row is discarded; rows queued earlier stay
     * pending.
     *
     * Throws: `PgSQLErrorException` if `start()` was not called or a column
     * has no matching member.
     */
    void row(T)(auto ref const T param) if (!isValueType!T)
    {
        ensureStarted();

        immutable mark = beginRow();
        {
            scope (failure) rollbackRow(mark);

            foreach (i, ref fieldHash; fieldsHash_)
            {
                if (i)
                    values_.put(',');

                bool fieldFound = false;
                foreach (member; __traits(allMembers, T))
                {
                    tryAppendField!member(param, fieldHash, fieldFound);
                    if (fieldFound)
                        break;
                }
                if (!fieldFound)
                    throw new PgSQLErrorException(format("field '%s' was not found in struct => '%s' members", fieldsNames_[i], typeid(Unqual!T).name));
            }
            values_.put(')');
        }

        finishRow();
    }

    /**
     * Queues one row from positional values. Non-string array arguments are
     * expanded into one value per element; the total must match the number of
     * columns given to `start()`.
     *
     * If this row fails, only this row is discarded; rows queued earlier stay
     * pending.
     *
     * Throws: `PgSQLErrorException` if `start()` was not called or the value
     * count does not match the column count.
     */
    void row(Values...)(Values values) if (allSatisfy!(isValueType, Values))
    {
        ensureStarted();

        size_t valueCount = Values.length;
        foreach (i, Value; Values)
        {
            static if (isArray!Value && !isSomeString!(OriginalType!Value))
                valueCount = (valueCount - 1) + values[i].length;
        }

        if (valueCount != fields_)
            throw new PgSQLErrorException(format("Wrong number of parameters for row. Got %d but expected %d.", valueCount, fields_));

        immutable mark = beginRow();
        {
            scope (failure) rollbackRow(mark);

            bool needComma = false;
            foreach (i, Value; Values)
            {
                static if (isArray!Value && !isSomeString!(OriginalType!Value))
                {
                    // An empty array contributes no values, so no separator either.
                    if (values[i].length)
                    {
                        if (needComma)
                            values_.put(',');
                        ValueAppender.appendValues(values_, values[i]);
                        needComma = true;
                    }
                }
                else
                {
                    if (needComma)
                        values_.put(',');
                    ValueAppender.appendValue(values_, values[i]);
                    needComma = true;
                }
            }
            values_.put(')');
        }

        finishRow();
    }

    /// Number of rows queued successfully over this inserter's lifetime (sent or not).
    @property size_t rows() const
    {
        return rows_;
    }

    /// Number of rows buffered but not yet sent by `flush()`.
    @property size_t pending() const
    {
        return pending_;
    }

    /// Number of statements successfully executed by `flush()`.
    @property size_t flushes() const
    {
        return flushes_;
    }

    /// Executes the buffered statement, if any rows are pending.
    void flush()
    {
        if (!pending_)
            return;

        if (dupUpdate_.length)
        {
            values_.put(" on duplicate key update ");
            values_.put(dupUpdate_);
        }

        // Reset before executing so a failing execute() leaves the inserter
        // empty; nothing writes to values_ again before execute() returns, so
        // the slice stays valid.
        auto sql = values_.data;
        reset();

        conn_.execute(sql);
        ++flushes_;
    }

private:

    /// Discards the buffered statement.
    void reset()
    {
        values_.clear();
        pending_ = 0;
    }

    /// Throws unless start() has built a statement prefix.
    void ensureStarted() const
    {
        if (start_.empty)
            throw new PgSQLErrorException("Inserter must be initialized with a call to start()");
    }

    /// Opens a new row tuple in the buffer; returns the prior buffer length for rollbackRow().
    size_t beginRow()
    {
        immutable mark = values_.data.length;
        if (!pending_)
            values_.put(start_);
        values_.put(pending_ ? ",(" : "(");
        ++pending_;
        return mark;
    }

    /// Undoes a partially written row: truncates the buffer to `mark` and un-counts it.
    void rollbackRow(size_t mark)
    {
        values_.shrinkTo(mark);
        --pending_;
    }

    /// Flushes if the buffer has outgrown flushThreshold, then counts the completed row.
    void finishRow()
    {
        if (values_.data.length > flushThreshold)
            flush();
        ++rows_;
    }

    /// Builds the "`col`=values(`col`)" assignment list used by OnDuplicate.UpdateAll.
    static string updateAllClause(const(string)[] names)
    {
        string[] assignments;
        foreach (name; names)
            assignments ~= "`" ~ name ~ "`=values(`" ~ name ~ "`)";
        return assignments.join(",");
    }

    string start_;              // statement prefix built by start()
    string dupUpdate_;          // text appended after " on duplicate key update "
    Appender!(char[]) values_;  // SQL text of the pending statement

    Connection* conn_;
    size_t pending_;            // rows in values_ not yet executed
    size_t flushes_;
    size_t fields_;             // column count from start()
    size_t rows_;
    string[] fieldsNames_;
    size_t[] fieldsHash_;
}