1
use std::{collections::HashMap, error::Error as StdError, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeZone, Utc};
5
use futures::TryStreamExt;
6
use serde_json;
7
use sql_builder::{quote, SqlBuilder};
8
use sqlx::SqlitePool;
9

            
10
use super::{
11
    super::user::{
12
        Cursor, ListOptions, ListQueryCond, QueryCond, SortKey, Updates, User, UserModel,
13
    },
14
    build_where_like,
15
};
16

            
17
/// Model instance.
///
/// SQLite-backed implementation of [`UserModel`] operating on the `user` table.
pub struct Model {
    /// The associated database connection pool, shared through an [`Arc`].
    conn: Arc<SqlitePool>,
}
22

            
23
/// Cursor instance.
///
/// The SQLite implementation only tracks how many rows have been handed out so far. When a
/// listing is resumed, `Model::list` re-applies the caller's list options and skips `offset`
/// rows. `Default` yields the same fresh cursor as `DbCursor::new()` (offset 0).
#[derive(Debug, Default)]
pub struct DbCursor {
    /// Number of rows already returned through this cursor.
    offset: u64,
}
29

            
30
/// SQLite schema.
///
/// One row of the `user` table, decoded through `sqlx::FromRow`. Field names must match the
/// column names declared in `TABLE_INIT_SQL`.
// NOTE(review): `roles` and `info` are nullable columns in `TABLE_INIT_SQL` but are decoded as
// `String` here; a NULL value would presumably fail to decode — confirm whether that can occur.
#[derive(sqlx::FromRow)]
struct Schema {
    user_id: String,
    /// Account name; `add` lowercases it before inserting.
    account: String,
    /// i64 as time tick from Epoch in milliseconds.
    created_at: i64,
    /// i64 as time tick from Epoch in milliseconds.
    modified_at: i64,
    /// i64 as time tick from Epoch in milliseconds; `None` when the column is NULL.
    verified_at: Option<i64>,
    /// i64 as time tick from Epoch in milliseconds; `None` when the column is NULL.
    expired_at: Option<i64>,
    /// i64 as time tick from Epoch in milliseconds; `None` when the column is NULL.
    disabled_at: Option<i64>,
    /// JSON string value such as `{"role1":true,"role2":false}`.
    roles: String,
    password: String,
    salt: String,
    name: String,
    /// JSON string value, parsed with `serde_json::from_str` when the row is read back.
    info: String,
}
52

            
53
/// Result row of the `SELECT COUNT(*) ...` query issued by `count`.
///
/// Use "COUNT(*)" instead of "COUNT(fields...)" to simplify the implementation. The result
/// column is then named `COUNT(*)`, which the `rename` attribute maps onto `count`.
#[derive(sqlx::FromRow)]
struct CountSchema {
    #[sqlx(rename = "COUNT(*)")]
    count: i64,
}
59

            
60
/// Name of the SQLite table that stores users.
const TABLE_NAME: &str = "user";
/// Columns of the `user` table, in the order used by the `SELECT` in `get` and the `INSERT` in
/// `add` (whose value list follows the same order).
const FIELDS: &[&str] = &[
    "user_id",
    "account",
    "created_at",
    "modified_at",
    "verified_at",
    "expired_at",
    "disabled_at",
    "roles",
    "password",
    "salt",
    "name",
    "info",
];
/// Creates the `user` table when it does not exist yet.
///
/// Each trailing `\` joins the lines without a newline or the following indentation.
const TABLE_INIT_SQL: &str = "\
    CREATE TABLE IF NOT EXISTS user (\
    user_id TEXT NOT NULL UNIQUE,\
    account TEXT NOT NULL UNIQUE,\
    created_at INTEGER NOT NULL,\
    modified_at INTEGER NOT NULL,\
    verified_at INTEGER,\
    expired_at INTEGER,\
    disabled_at INTEGER,\
    roles TEXT,\
    password TEXT NOT NULL,\
    salt TEXT NOT NULL,\
    name TEXT NOT NULL,\
    info TEXT,\
    PRIMARY KEY (user_id))";
90

            
91
impl Model {
92
    /// To create the model instance with a database connection.
93
11
    pub async fn new(conn: Arc<SqlitePool>) -> Result<Self, Box<dyn StdError>> {
94
11
        let model = Model { conn };
95
11
        model.init().await?;
96
11
        Ok(model)
97
11
    }
98
}
99

            
100
#[async_trait]
101
impl UserModel for Model {
102
20
    async fn init(&self) -> Result<(), Box<dyn StdError>> {
103
20
        let _ = sqlx::query(TABLE_INIT_SQL)
104
20
            .execute(self.conn.as_ref())
105
20
            .await?;
106
20
        Ok(())
107
40
    }
108

            
109
21
    async fn count(&self, cond: &ListQueryCond) -> Result<u64, Box<dyn StdError>> {
110
21
        let sql = build_list_where(SqlBuilder::select_from(TABLE_NAME).count("*"), &cond).sql()?;
111

            
112
21
        let result: Result<CountSchema, sqlx::Error> = sqlx::query_as(sql.as_str())
113
21
            .fetch_one(self.conn.as_ref())
114
21
            .await;
115

            
116
21
        let row = match result {
117
            Err(e) => {
118
                return Err(Box::new(e));
119
            }
120
21
            Ok(row) => row,
121
21
        };
122
21
        Ok(row.count as u64)
123
42
    }
124

            
125
    async fn list(
126
        &self,
127
        opts: &ListOptions,
128
        cursor: Option<Box<dyn Cursor>>,
129
83
    ) -> Result<(Vec<User>, Option<Box<dyn Cursor>>), Box<dyn StdError>> {
130
83
        let mut cursor = match cursor {
131
72
            None => Box::new(DbCursor::new()),
132
11
            Some(cursor) => cursor,
133
        };
134

            
135
83
        let mut opts = ListOptions { ..*opts };
136
83
        if let Some(offset) = opts.offset {
137
14
            opts.offset = Some(offset + cursor.offset());
138
69
        } else {
139
69
            opts.offset = Some(cursor.offset());
140
69
        }
141
83
        let opts_limit = opts.limit;
142
83
        if let Some(limit) = opts_limit {
143
44
            if limit > 0 {
144
43
                if cursor.offset() >= limit {
145
3
                    return Ok((vec![], None));
146
40
                }
147
40
                opts.limit = Some(limit - cursor.offset());
148
1
            }
149
39
        }
150
80
        let mut builder = SqlBuilder::select_from(TABLE_NAME);
151
80
        build_limit_offset(&mut builder, &opts);
152
80
        build_sort(&mut builder, &opts);
153
80
        let sql = build_list_where(&mut builder, opts.cond).sql()?;
154

            
155
80
        let mut rows = sqlx::query_as::<_, Schema>(sql.as_str()).fetch(self.conn.as_ref());
156
80

            
157
80
        let mut count: u64 = 0;
158
80
        let mut list = vec![];
159
1201
        while let Some(row) = rows.try_next().await? {
160
1136
            let _ = cursor.as_mut().try_next().await?;
161
1136
            let roles: HashMap<String, bool> = match serde_json::from_str(row.roles.as_str()) {
162
                Err(_) => HashMap::new(),
163
1136
                Ok(roles) => roles,
164
            };
165
1136
            list.push(User {
166
1136
                user_id: row.user_id,
167
1136
                account: row.account,
168
1136
                created_at: Utc.timestamp_nanos(row.created_at * 1000000),
169
1136
                modified_at: Utc.timestamp_nanos(row.modified_at * 1000000),
170
1136
                verified_at: match row.verified_at {
171
117
                    None => None,
172
1019
                    Some(value) => Some(Utc.timestamp_nanos(value * 1000000)),
173
                },
174
1136
                expired_at: match row.expired_at {
175
1077
                    None => None,
176
59
                    Some(value) => Some(Utc.timestamp_nanos(value * 1000000)),
177
                },
178
1136
                disabled_at: match row.disabled_at {
179
1075
                    None => None,
180
61
                    Some(value) => Some(Utc.timestamp_nanos(value * 1000000)),
181
                },
182
1136
                roles,
183
1136
                password: row.password,
184
1136
                salt: row.salt,
185
1136
                name: row.name,
186
1136
                info: serde_json::from_str(row.info.as_str())?,
187
            });
188
1136
            if let Some(limit) = opts_limit {
189
444
                if limit > 0 && cursor.offset() >= limit {
190
7
                    if let Some(cursor_max) = opts.cursor_max {
191
6
                        if (count + 1) >= cursor_max {
192
3
                            return Ok((list, Some(cursor)));
193
3
                        }
194
1
                    }
195
4
                    return Ok((list, None));
196
437
                }
197
692
            }
198
1129
            if let Some(cursor_max) = opts.cursor_max {
199
1026
                count += 1;
200
1026
                if count >= cursor_max {
201
8
                    return Ok((list, Some(cursor)));
202
1018
                }
203
103
            }
204
        }
205
65
        Ok((list, None))
206
166
    }
207

            
208
449
    async fn get(&self, cond: &QueryCond) -> Result<Option<User>, Box<dyn StdError>> {
209
449
        let sql = build_where(SqlBuilder::select_from(TABLE_NAME).fields(FIELDS), &cond).sql()?;
210

            
211
449
        let result: Result<Schema, sqlx::Error> = sqlx::query_as(sql.as_str())
212
449
            .fetch_one(self.conn.as_ref())
213
449
            .await;
214

            
215
449
        let row = match result {
216
11
            Err(e) => match e {
217
11
                sqlx::Error::RowNotFound => return Ok(None),
218
                _ => return Err(Box::new(e)),
219
            },
220
438
            Ok(row) => row,
221
        };
222

            
223
438
        let roles: HashMap<String, bool> = match serde_json::from_str(row.roles.as_str()) {
224
            Err(_) => HashMap::new(),
225
438
            Ok(roles) => roles,
226
        };
227
        Ok(Some(User {
228
438
            user_id: row.user_id,
229
438
            account: row.account,
230
438
            created_at: Utc.timestamp_nanos(row.created_at * 1000000),
231
438
            modified_at: Utc.timestamp_nanos(row.modified_at * 1000000),
232
438
            verified_at: match row.verified_at {
233
12
                None => None,
234
426
                Some(value) => Some(Utc.timestamp_nanos(value * 1000000)),
235
            },
236
438
            expired_at: match row.expired_at {
237
428
                None => None,
238
10
                Some(value) => Some(Utc.timestamp_nanos(value * 1000000)),
239
            },
240
438
            disabled_at: match row.disabled_at {
241
427
                None => None,
242
11
                Some(value) => Some(Utc.timestamp_nanos(value * 1000000)),
243
            },
244
438
            roles,
245
438
            password: row.password,
246
438
            salt: row.salt,
247
438
            name: row.name,
248
438
            info: serde_json::from_str(row.info.as_str())?,
249
        }))
250
898
    }
251

            
252
549
    async fn add(&self, user: &User) -> Result<(), Box<dyn StdError>> {
253
549
        let roles = match serde_json::to_string(&user.roles) {
254
            Err(_) => quote("{}"),
255
549
            Ok(value) => quote(value.as_str()),
256
        };
257
549
        let info = match serde_json::to_string(&user.info) {
258
            Err(_) => quote("{}"),
259
549
            Ok(value) => quote(value.as_str()),
260
        };
261
549
        let values = vec![
262
549
            quote(user.user_id.as_str()),
263
549
            quote(user.account.to_lowercase().as_str()),
264
549
            user.created_at.timestamp_millis().to_string(),
265
549
            user.modified_at.timestamp_millis().to_string(),
266
549
            match user.verified_at {
267
31
                None => "NULL".to_string(),
268
518
                Some(value) => value.timestamp_millis().to_string(),
269
            },
270
549
            match user.expired_at {
271
538
                None => "NULL".to_string(),
272
11
                Some(value) => value.timestamp_millis().to_string(),
273
            },
274
549
            match user.disabled_at {
275
541
                None => "NULL".to_string(),
276
8
                Some(value) => value.timestamp_millis().to_string(),
277
            },
278
549
            roles,
279
549
            quote(user.password.as_str()),
280
549
            quote(user.salt.as_str()),
281
549
            quote(user.name.as_str()),
282
549
            info,
283
        ];
284
549
        let sql = SqlBuilder::insert_into(TABLE_NAME)
285
549
            .fields(FIELDS)
286
549
            .values(&values)
287
549
            .sql()?;
288
549
        let _ = sqlx::query(sql.as_str())
289
549
            .execute(self.conn.as_ref())
290
549
            .await?;
291
547
        Ok(())
292
1098
    }
293

            
294
5
    async fn del(&self, user_id: &str) -> Result<(), Box<dyn StdError>> {
295
5
        let sql = SqlBuilder::delete_from(TABLE_NAME)
296
5
            .and_where_eq("user_id", quote(user_id))
297
5
            .sql()?;
298
5
        let _ = sqlx::query(sql.as_str())
299
5
            .execute(self.conn.as_ref())
300
5
            .await?;
301
5
        Ok(())
302
10
    }
303

            
304
24
    async fn update(&self, user_id: &str, updates: &Updates) -> Result<(), Box<dyn StdError>> {
305
23
        let sql =
306
24
            match build_update_where(&mut SqlBuilder::update_table(TABLE_NAME), user_id, updates) {
307
1
                None => return Ok(()),
308
23
                Some(builder) => builder.sql()?,
309
            };
310
23
        let _ = sqlx::query(sql.as_str())
311
23
            .execute(self.conn.as_ref())
312
23
            .await?;
313
23
        Ok(())
314
48
    }
315
}
316

            
317
impl DbCursor {
318
    /// To create the cursor instance.
319
72
    pub fn new() -> Self {
320
72
        DbCursor { offset: 0 }
321
72
    }
322
}
323

            
324
#[async_trait]
325
impl Cursor for DbCursor {
326
1136
    async fn try_next(&mut self) -> Result<Option<User>, Box<dyn StdError>> {
327
1136
        self.offset += 1;
328
1136
        Ok(None)
329
2272
    }
330

            
331
610
    fn offset(&self) -> u64 {
332
610
        self.offset
333
610
    }
334
}
335

            
336
/// Transforms query conditions to the SQL builder.
337
449
fn build_where<'a>(builder: &'a mut SqlBuilder, cond: &QueryCond<'a>) -> &'a mut SqlBuilder {
338
449
    if let Some(value) = cond.user_id {
339
299
        builder.and_where_eq("user_id", quote(value));
340
299
    }
341
449
    if let Some(value) = cond.account {
342
150
        builder.and_where_eq("account", quote(value.to_lowercase().as_str()));
343
299
    }
344
449
    builder
345
449
}
346

            
347
/// Transforms query conditions to the SQL builder.
348
101
fn build_list_where<'a>(
349
101
    builder: &'a mut SqlBuilder,
350
101
    cond: &ListQueryCond<'a>,
351
101
) -> &'a mut SqlBuilder {
352
101
    if let Some(value) = cond.user_id {
353
2
        builder.and_where_eq("user_id", quote(value));
354
99
    }
355
101
    if let Some(value) = cond.account {
356
11
        builder.and_where_eq("account", quote(value.to_lowercase().as_str()));
357
90
    }
358
101
    if let Some(value) = cond.account_contains {
359
35
        build_where_like(builder, "account", value.to_lowercase().as_str());
360
66
    }
361
101
    if let Some(value) = cond.verified_at {
362
6
        if value {
363
4
            builder.and_where_is_not_null("verified_at");
364
4
        } else {
365
2
            builder.and_where_is_null("verified_at");
366
2
        }
367
95
    }
368
101
    if let Some(value) = cond.disabled_at {
369
4
        if value {
370
2
            builder.and_where_is_not_null("disabled_at");
371
2
        } else {
372
2
            builder.and_where_is_null("disabled_at");
373
2
        }
374
97
    }
375
101
    if let Some(value) = cond.name_contains {
376
6
        build_where_like(builder, "name", value);
377
95
    }
378
101
    builder
379
101
}
380

            
381
/// Transforms model options to the SQL builder.
382
80
fn build_limit_offset<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
383
80
    if let Some(value) = opts.limit {
384
41
        if value > 0 {
385
40
            builder.limit(value);
386
40
        }
387
39
    }
388
80
    if let Some(value) = opts.offset {
389
80
        match opts.limit {
390
39
            None => builder.limit(-1).offset(value),
391
1
            Some(0) => builder.limit(-1).offset(value),
392
40
            _ => builder.offset(value),
393
        };
394
    }
395
80
    builder
396
80
}
397

            
398
/// Transforms model options to the SQL builder.
399
80
fn build_sort<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
400
80
    if let Some(sort_cond) = opts.sort.as_ref() {
401
73
        for cond in sort_cond.iter() {
402
73
            let key = match cond.key {
403
49
                SortKey::Account => "account",
404
6
                SortKey::CreatedAt => "created_at",
405
4
                SortKey::ModifiedAt => "modified_at",
406
4
                SortKey::VerifiedAt => "verified_at",
407
2
                SortKey::ExpiredAt => "expired_at",
408
2
                SortKey::DisabledAt => "disabled_at",
409
6
                SortKey::Name => "name",
410
            };
411
73
            builder.order_by(key, !cond.asc);
412
        }
413
12
    }
414
80
    builder
415
80
}
416

            
417
/// Transforms query conditions and the model object to the SQL builder.
418
24
fn build_update_where<'a>(
419
24
    builder: &'a mut SqlBuilder,
420
24
    user_id: &str,
421
24
    updates: &Updates,
422
24
) -> Option<&'a mut SqlBuilder> {
423
24
    let mut count = 0;
424
24
    if let Some(value) = updates.modified_at.as_ref() {
425
23
        builder.set("modified_at", value.timestamp_millis());
426
23
        count += 1;
427
23
    }
428
24
    if let Some(value) = updates.verified_at.as_ref() {
429
4
        builder.set("verified_at", value.timestamp_millis());
430
4
        count += 1;
431
20
    }
432
24
    if let Some(value) = updates.expired_at.as_ref() {
433
4
        match value {
434
3
            None => {
435
3
                builder.set("expired_at", "NULL");
436
3
            }
437
1
            Some(value) => {
438
1
                builder.set("expired_at", value.timestamp_millis());
439
1
            }
440
        }
441
4
        count += 1;
442
20
    }
443
24
    if let Some(value) = updates.disabled_at.as_ref() {
444
8
        match value {
445
4
            None => {
446
4
                builder.set("disabled_at", "NULL");
447
4
            }
448
4
            Some(value) => {
449
4
                builder.set("disabled_at", value.timestamp_millis());
450
4
            }
451
        }
452
8
        count += 1;
453
16
    }
454
24
    if let Some(value) = updates.roles {
455
6
        builder.set(
456
6
            "roles",
457
6
            match serde_json::to_string(value) {
458
                Err(_) => quote("{}"),
459
6
                Ok(value) => quote(value.as_str()),
460
            },
461
        );
462
6
        count += 1;
463
18
    }
464
24
    if let Some(value) = updates.password.as_ref() {
465
7
        builder.set("password", quote(value));
466
7
        count += 1;
467
17
    }
468
24
    if let Some(value) = updates.salt.as_ref() {
469
7
        builder.set("salt", quote(value));
470
7
        count += 1;
471
17
    }
472
24
    if let Some(value) = updates.name {
473
16
        builder.set("name", quote(value));
474
16
        count += 1;
475
16
    }
476
24
    if let Some(value) = updates.info {
477
16
        match serde_json::to_string(value) {
478
            Err(_) => {
479
                builder.set("info", quote("{}"));
480
            }
481
16
            Ok(value) => {
482
16
                builder.set("info", quote(value));
483
16
            }
484
        }
485
16
        count += 1;
486
8
    }
487
24
    if count == 0 {
488
1
        return None;
489
23
    }
490
23

            
491
23
    builder.and_where_eq("user_id", quote(user_id));
492
23
    Some(builder)
493
24
}