1
use std::{error::Error as StdError, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeZone, Utc};
5
use futures::TryStreamExt;
6
use sql_builder::{quote, SqlBuilder};
7
use sqlx::SqlitePool;
8

            
9
use super::{
10
    super::unit::{
11
        Cursor, ListOptions, ListQueryCond, QueryCond, SortKey, Unit, UnitModel, UpdateQueryCond,
12
        Updates,
13
    },
14
    build_where_like,
15
};
16

            
17
/// Model instance.
18
pub struct Model {
19
    /// The associated database connection.
20
    conn: Arc<SqlitePool>,
21
}
22

            
23
/// Paging cursor returned by the SQLite `list` implementation.
///
/// It only records how many rows have been handed to the caller so far. The list options are
/// supplied again on every `list` call, and this offset is added to the caller's own offset
/// to resume where the previous page stopped.
pub struct DbCursor {
    /// Number of rows consumed through this cursor so far.
    offset: u64,
}
29

            
30
/// SQLite schema.
31
#[derive(sqlx::FromRow)]
32
struct Schema {
33
    unit_id: String,
34
    code: String,
35
    /// i64 as time tick from Epoch in milliseconds.
36
    created_at: i64,
37
    /// i64 as time tick from Epoch in milliseconds.
38
    modified_at: i64,
39
    owner_id: String,
40
    /// Space-separated value such as `member_id1 member_id2`.
41
    member_ids: String,
42
    name: String,
43
    info: String,
44
}
45

            
46
/// Use "COUNT(*)" instead of "COUNT(fields...)" to simplify the implementation.
47
#[derive(sqlx::FromRow)]
48
struct CountSchema {
49
    #[sqlx(rename = "COUNT(*)")]
50
    count: i64,
51
}
52

            
53
/// Name of the table that stores units.
const TABLE_NAME: &str = "unit";
/// Columns of the `unit` table, in the order used for `SELECT` and `INSERT` statements.
const FIELDS: &[&str] = &[
    "unit_id",
    "code",
    "created_at",
    "modified_at",
    "owner_id",
    "member_ids",
    "name",
    "info",
];
/// DDL that creates the `unit` table if it does not exist yet.
///
/// The trailing backslashes join the lines without inserting newlines or indentation.
const TABLE_INIT_SQL: &str = "\
    CREATE TABLE IF NOT EXISTS unit (\
    unit_id TEXT NOT NULL UNIQUE,\
    code TEXT NOT NULL UNIQUE,\
    created_at INTEGER NOT NULL,\
    modified_at INTEGER NOT NULL,\
    owner_id TEXT NOT NULL,\
    member_ids TEXT NOT NULL,\
    name TEXT NOT NULL,\
    info TEXT,\
    PRIMARY KEY (unit_id))";
75

            
76
impl Model {
77
    /// To create the model instance with a database connection.
78
56
    pub async fn new(conn: Arc<SqlitePool>) -> Result<Self, Box<dyn StdError>> {
79
56
        let model = Model { conn };
80
56
        model.init().await?;
81
56
        Ok(model)
82
56
    }
83
}
84

            
85
#[async_trait]
86
impl UnitModel for Model {
87
96
    async fn init(&self) -> Result<(), Box<dyn StdError>> {
88
96
        let _ = sqlx::query(TABLE_INIT_SQL)
89
96
            .execute(self.conn.as_ref())
90
96
            .await?;
91
96
        Ok(())
92
192
    }
93

            
94
86
    async fn count(&self, cond: &ListQueryCond) -> Result<u64, Box<dyn StdError>> {
95
86
        let sql = build_list_where(SqlBuilder::select_from(TABLE_NAME).count("*"), &cond).sql()?;
96

            
97
86
        let result: Result<CountSchema, sqlx::Error> = sqlx::query_as(sql.as_str())
98
86
            .fetch_one(self.conn.as_ref())
99
86
            .await;
100

            
101
86
        let row = match result {
102
            Err(e) => return Err(Box::new(e)),
103
86
            Ok(row) => row,
104
86
        };
105
86
        Ok(row.count as u64)
106
172
    }
107

            
108
    async fn list(
109
        &self,
110
        opts: &ListOptions,
111
        cursor: Option<Box<dyn Cursor>>,
112
252
    ) -> Result<(Vec<Unit>, Option<Box<dyn Cursor>>), Box<dyn StdError>> {
113
252
        let mut cursor = match cursor {
114
216
            None => Box::new(DbCursor::new()),
115
36
            Some(cursor) => cursor,
116
        };
117

            
118
252
        let mut opts = ListOptions { ..*opts };
119
252
        if let Some(offset) = opts.offset {
120
48
            opts.offset = Some(offset + cursor.offset());
121
204
        } else {
122
204
            opts.offset = Some(cursor.offset());
123
204
        }
124
252
        let opts_limit = opts.limit;
125
252
        if let Some(limit) = opts_limit {
126
160
            if limit > 0 {
127
158
                if cursor.offset() >= limit {
128
10
                    return Ok((vec![], None));
129
148
                }
130
148
                opts.limit = Some(limit - cursor.offset());
131
2
            }
132
92
        }
133
242
        let mut builder = SqlBuilder::select_from(TABLE_NAME);
134
242
        build_limit_offset(&mut builder, &opts);
135
242
        build_sort(&mut builder, &opts);
136
242
        let sql = build_list_where(&mut builder, opts.cond).sql()?;
137

            
138
242
        let mut rows = sqlx::query_as::<_, Schema>(sql.as_str()).fetch(self.conn.as_ref());
139
242

            
140
242
        let mut count: u64 = 0;
141
242
        let mut list = vec![];
142
3948
        while let Some(row) = rows.try_next().await? {
143
3762
            let _ = cursor.as_mut().try_next().await?;
144
3762
            let member_ids = row
145
3762
                .member_ids
146
3762
                .split(" ")
147
3868
                .filter_map(|x| {
148
3868
                    if x.len() > 0 {
149
3712
                        Some(x.to_string())
150
                    } else {
151
156
                        None
152
                    }
153
3868
                })
154
3762
                .collect();
155
3762
            list.push(Unit {
156
3762
                unit_id: row.unit_id,
157
3762
                code: row.code,
158
3762
                created_at: Utc.timestamp_nanos(row.created_at * 1000000),
159
3762
                modified_at: Utc.timestamp_nanos(row.modified_at * 1000000),
160
3762
                owner_id: row.owner_id,
161
3762
                member_ids,
162
3762
                name: row.name,
163
3762
                info: serde_json::from_str(row.info.as_str())?,
164
            });
165
3762
            if let Some(limit) = opts_limit {
166
2006
                if limit > 0 && cursor.offset() >= limit {
167
30
                    if let Some(cursor_max) = opts.cursor_max {
168
28
                        if (count + 1) >= cursor_max {
169
10
                            return Ok((list, Some(cursor)));
170
18
                        }
171
2
                    }
172
20
                    return Ok((list, None));
173
1976
                }
174
1756
            }
175
3732
            if let Some(cursor_max) = opts.cursor_max {
176
3566
                count += 1;
177
3566
                if count >= cursor_max {
178
26
                    return Ok((list, Some(cursor)));
179
3540
                }
180
166
            }
181
        }
182
186
        Ok((list, None))
183
504
    }
184

            
185
2630
    async fn get(&self, cond: &QueryCond) -> Result<Option<Unit>, Box<dyn StdError>> {
186
2630
        let sql = build_where(SqlBuilder::select_from(TABLE_NAME).fields(FIELDS), &cond).sql()?;
187

            
188
2630
        let result: Result<Schema, sqlx::Error> = sqlx::query_as(sql.as_str())
189
2630
            .fetch_one(self.conn.as_ref())
190
2630
            .await;
191

            
192
2630
        let row = match result {
193
472
            Err(e) => match e {
194
472
                sqlx::Error::RowNotFound => return Ok(None),
195
                _ => return Err(Box::new(e)),
196
            },
197
2158
            Ok(row) => row,
198
2158
        };
199
2158

            
200
2158
        let member_ids = row
201
2158
            .member_ids
202
2158
            .split(" ")
203
2908
            .filter_map(|x| {
204
2908
                if x.len() > 0 {
205
2892
                    Some(x.to_string())
206
                } else {
207
16
                    None
208
                }
209
2908
            })
210
2158
            .collect();
211
2158
        Ok(Some(Unit {
212
2158
            unit_id: row.unit_id,
213
2158
            code: row.code,
214
2158
            created_at: Utc.timestamp_nanos(row.created_at * 1000000),
215
2158
            modified_at: Utc.timestamp_nanos(row.modified_at * 1000000),
216
2158
            owner_id: row.owner_id,
217
2158
            member_ids,
218
2158
            name: row.name,
219
2158
            info: serde_json::from_str(row.info.as_str())?,
220
        }))
221
5260
    }
222

            
223
2718
    async fn add(&self, unit: &Unit) -> Result<(), Box<dyn StdError>> {
224
2718
        let info = match serde_json::to_string(&unit.info) {
225
            Err(_) => quote("{}"),
226
2718
            Ok(value) => quote(value.as_str()),
227
        };
228
2718
        let values = vec![
229
2718
            quote(unit.unit_id.as_str()),
230
2718
            quote(unit.code.as_str()),
231
2718
            unit.created_at.timestamp_millis().to_string(),
232
2718
            unit.modified_at.timestamp_millis().to_string(),
233
2718
            quote(unit.owner_id.as_str()),
234
2718
            quote(unit.member_ids.join(" ")),
235
2718
            quote(unit.name.as_str()),
236
2718
            info,
237
2718
        ];
238
2718
        let sql = SqlBuilder::insert_into(TABLE_NAME)
239
2718
            .fields(FIELDS)
240
2718
            .values(&values)
241
2718
            .sql()?;
242
2718
        let _ = sqlx::query(sql.as_str())
243
2718
            .execute(self.conn.as_ref())
244
2718
            .await?;
245
2714
        Ok(())
246
5436
    }
247

            
248
38
    async fn del(&self, cond: &QueryCond) -> Result<(), Box<dyn StdError>> {
249
38
        let sql = build_where(&mut SqlBuilder::delete_from(TABLE_NAME), cond).sql()?;
250
38
        let _ = sqlx::query(sql.as_str())
251
38
            .execute(self.conn.as_ref())
252
38
            .await?;
253
38
        Ok(())
254
76
    }
255

            
256
    async fn update(
257
        &self,
258
        cond: &UpdateQueryCond,
259
        updates: &Updates,
260
30
    ) -> Result<(), Box<dyn StdError>> {
261
30
        let sql = match build_update_where(&mut SqlBuilder::update_table(TABLE_NAME), cond, updates)
262
        {
263
2
            None => return Ok(()),
264
28
            Some(builder) => builder.sql()?,
265
        };
266
28
        let _ = sqlx::query(sql.as_str())
267
28
            .execute(self.conn.as_ref())
268
28
            .await?;
269
28
        Ok(())
270
60
    }
271
}
272

            
273
impl DbCursor {
274
    /// To create the cursor instance.
275
216
    pub fn new() -> Self {
276
216
        DbCursor { offset: 0 }
277
216
    }
278
}
279

            
280
#[async_trait]
281
impl Cursor for DbCursor {
282
3762
    async fn try_next(&mut self) -> Result<Option<Unit>, Box<dyn StdError>> {
283
3762
        self.offset += 1;
284
3762
        Ok(None)
285
7524
    }
286

            
287
2564
    fn offset(&self) -> u64 {
288
2564
        self.offset
289
2564
    }
290
}
291

            
292
/// Transforms query conditions to the SQL builder.
293
2668
fn build_where<'a>(builder: &'a mut SqlBuilder, cond: &QueryCond<'a>) -> &'a mut SqlBuilder {
294
2668
    if let Some(value) = cond.unit_id {
295
2618
        builder.and_where_eq("unit_id", quote(value));
296
2618
    }
297
2668
    if let Some(value) = cond.code {
298
48
        builder.and_where_eq("code", quote(value));
299
2620
    }
300
2668
    if let Some(value) = cond.owner_id {
301
550
        builder.and_where_eq("owner_id", quote(value));
302
2118
    }
303
2668
    if let Some(value) = cond.member_id {
304
868
        // Use LIKE because one ID will not be part of another ID.
305
868
        build_where_like(builder, "member_ids", value);
306
1800
    }
307
2668
    builder
308
2668
}
309

            
310
/// Transforms query conditions to the SQL builder.
311
328
fn build_list_where<'a>(
312
328
    builder: &'a mut SqlBuilder,
313
328
    cond: &ListQueryCond<'a>,
314
328
) -> &'a mut SqlBuilder {
315
328
    if let Some(value) = cond.owner_id {
316
88
        builder.and_where_eq("owner_id", quote(value));
317
240
    }
318
328
    if let Some(value) = cond.member_id {
319
156
        // Use LIKE because one ID will not be part of another ID.
320
156
        build_where_like(builder, "member_ids", value);
321
172
    }
322
328
    if let Some(value) = cond.unit_id {
323
12
        builder.and_where_eq("unit_id", quote(value));
324
316
    }
325
328
    if let Some(value) = cond.code_contains {
326
82
        build_where_like(builder, "code", value.to_lowercase().as_str());
327
246
    }
328
328
    if let Some(value) = cond.name_contains {
329
16
        build_where_like(builder, "name", value.to_lowercase().as_str());
330
312
    }
331
328
    builder
332
328
}
333

            
334
/// Transforms model options to the SQL builder.
335
242
fn build_limit_offset<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
336
242
    if let Some(value) = opts.limit {
337
150
        if value > 0 {
338
148
            builder.limit(value);
339
148
        }
340
92
    }
341
242
    if let Some(value) = opts.offset {
342
242
        match opts.limit {
343
92
            None => builder.limit(-1).offset(value),
344
2
            Some(0) => builder.limit(-1).offset(value),
345
148
            _ => builder.offset(value),
346
        };
347
    }
348
242
    builder
349
242
}
350

            
351
/// Transforms model options to the SQL builder.
352
242
fn build_sort<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
353
242
    if let Some(sort_cond) = opts.sort.as_ref() {
354
200
        for cond in sort_cond.iter() {
355
200
            let key = match cond.key {
356
16
                SortKey::CreatedAt => "created_at",
357
12
                SortKey::ModifiedAt => "modified_at",
358
156
                SortKey::Code => "code",
359
16
                SortKey::Name => "name",
360
            };
361
200
            builder.order_by(key, !cond.asc);
362
        }
363
44
    }
364
242
    builder
365
242
}
366

            
367
/// Transforms query conditions and the model object to the SQL builder.
368
30
fn build_update_where<'a>(
369
30
    builder: &'a mut SqlBuilder,
370
30
    cond: &UpdateQueryCond<'a>,
371
30
    updates: &Updates,
372
30
) -> Option<&'a mut SqlBuilder> {
373
30
    let mut count = 0;
374
30
    if let Some(value) = updates.modified_at.as_ref() {
375
28
        builder.set("modified_at", value.timestamp_millis());
376
28
        count += 1;
377
28
    }
378
30
    if let Some(value) = updates.owner_id.as_ref() {
379
16
        builder.set("owner_id", quote(value));
380
16
        count += 1;
381
16
    }
382
30
    if let Some(value) = updates.member_ids.as_ref() {
383
16
        builder.set("member_ids", quote(value.join(" ")));
384
16
        count += 1;
385
16
    }
386
30
    if let Some(value) = updates.name.as_ref() {
387
20
        builder.set("name", quote(value));
388
20
        count += 1;
389
20
    }
390
30
    if let Some(value) = updates.info {
391
20
        match serde_json::to_string(value) {
392
            Err(_) => {
393
                builder.set("info", quote("{}"));
394
            }
395
20
            Ok(value) => {
396
20
                builder.set("info", quote(value));
397
20
            }
398
        }
399
20
        count += 1;
400
10
    }
401
30
    if count == 0 {
402
2
        return None;
403
28
    }
404
28

            
405
28
    builder.and_where_eq("unit_id", quote(cond.unit_id));
406
28
    Some(builder)
407
30
}