1
use std::{collections::HashMap, error::Error as StdError, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeZone, Utc};
5
use futures::TryStreamExt;
6
use serde_json;
7
use sql_builder::{quote, SqlBuilder};
8
use sqlx::SqlitePool;
9

            
10
use super::{
11
    super::user::{
12
        Cursor, ListOptions, ListQueryCond, QueryCond, SortKey, Updates, User, UserModel,
13
    },
14
    build_where_like,
15
};
16

            
17
/// Model instance.
18
pub struct Model {
19
    /// The associated database connection.
20
    conn: Arc<SqlitePool>,
21
}
22

            
23
/// Cursor instance.
24
///
25
/// The SQLite implementation uses the original list options and the progress offset.
26
pub struct DbCursor {
27
    offset: u64,
28
}
29

            
30
/// SQLite schema.
31
#[derive(sqlx::FromRow)]
32
struct Schema {
33
    user_id: String,
34
    account: String,
35
    /// i64 as time tick from Epoch in milliseconds.
36
    created_at: i64,
37
    /// i64 as time tick from Epoch in milliseconds.
38
    modified_at: i64,
39
    /// i64 as time tick from Epoch in milliseconds.
40
    verified_at: Option<i64>,
41
    /// i64 as time tick from Epoch in milliseconds.
42
    expired_at: Option<i64>,
43
    /// i64 as time tick from Epoch in milliseconds.
44
    disabled_at: Option<i64>,
45
    /// JSON string value such as `{"role1":true,"role2":false}`.
46
    roles: String,
47
    password: String,
48
    salt: String,
49
    name: String,
50
    info: String,
51
}
52

            
53
/// Use "COUNT(*)" instead of "COUNT(fields...)" to simplify the implementation.
54
#[derive(sqlx::FromRow)]
55
struct CountSchema {
56
    #[sqlx(rename = "COUNT(*)")]
57
    count: i64,
58
}
59

            
60
const TABLE_NAME: &'static str = "user";
61
const FIELDS: &'static [&'static str] = &[
62
    "user_id",
63
    "account",
64
    "created_at",
65
    "modified_at",
66
    "verified_at",
67
    "expired_at",
68
    "disabled_at",
69
    "roles",
70
    "password",
71
    "salt",
72
    "name",
73
    "info",
74
];
75
const TABLE_INIT_SQL: &'static str = "\
76
    CREATE TABLE IF NOT EXISTS user (\
77
    user_id TEXT NOT NULL UNIQUE,\
78
    account TEXT NOT NULL UNIQUE,\
79
    created_at INTEGER NOT NULL,\
80
    modified_at INTEGER NOT NULL,\
81
    verified_at INTEGER,\
82
    expired_at INTEGER,\
83
    disabled_at INTEGER,\
84
    roles TEXT,\
85
    password TEXT NOT NULL,\
86
    salt TEXT NOT NULL,\
87
    name TEXT NOT NULL,\
88
    info TEXT,\
89
    PRIMARY KEY (user_id))";
90

            
91
impl Model {
92
    /// To create the model instance with a database connection.
93
22
    pub async fn new(conn: Arc<SqlitePool>) -> Result<Self, Box<dyn StdError>> {
94
22
        let model = Model { conn };
95
22
        model.init().await?;
96
22
        Ok(model)
97
22
    }
98
}
99

            
100
#[async_trait]
101
impl UserModel for Model {
102
40
    async fn init(&self) -> Result<(), Box<dyn StdError>> {
103
40
        let _ = sqlx::query(TABLE_INIT_SQL)
104
40
            .execute(self.conn.as_ref())
105
40
            .await?;
106
40
        Ok(())
107
80
    }
108

            
109
42
    async fn count(&self, cond: &ListQueryCond) -> Result<u64, Box<dyn StdError>> {
110
42
        let sql = build_list_where(SqlBuilder::select_from(TABLE_NAME).count("*"), &cond).sql()?;
111

            
112
42
        let result: Result<CountSchema, sqlx::Error> = sqlx::query_as(sql.as_str())
113
42
            .fetch_one(self.conn.as_ref())
114
42
            .await;
115

            
116
42
        let row = match result {
117
            Err(e) => {
118
                return Err(Box::new(e));
119
            }
120
42
            Ok(row) => row,
121
42
        };
122
42
        Ok(row.count as u64)
123
84
    }
124

            
125
    async fn list(
126
        &self,
127
        opts: &ListOptions,
128
        cursor: Option<Box<dyn Cursor>>,
129
166
    ) -> Result<(Vec<User>, Option<Box<dyn Cursor>>), Box<dyn StdError>> {
130
166
        let mut cursor = match cursor {
131
144
            None => Box::new(DbCursor::new()),
132
22
            Some(cursor) => cursor,
133
        };
134

            
135
166
        let mut opts = ListOptions { ..*opts };
136
166
        if let Some(offset) = opts.offset {
137
28
            opts.offset = Some(offset + cursor.offset());
138
138
        } else {
139
138
            opts.offset = Some(cursor.offset());
140
138
        }
141
166
        let opts_limit = opts.limit;
142
166
        if let Some(limit) = opts_limit {
143
88
            if limit > 0 {
144
86
                if cursor.offset() >= limit {
145
6
                    return Ok((vec![], None));
146
80
                }
147
80
                opts.limit = Some(limit - cursor.offset());
148
2
            }
149
78
        }
150
160
        let mut builder = SqlBuilder::select_from(TABLE_NAME);
151
160
        build_limit_offset(&mut builder, &opts);
152
160
        build_sort(&mut builder, &opts);
153
160
        let sql = build_list_where(&mut builder, opts.cond).sql()?;
154

            
155
160
        let mut rows = sqlx::query_as::<_, Schema>(sql.as_str()).fetch(self.conn.as_ref());
156
160

            
157
160
        let mut count: u64 = 0;
158
160
        let mut list = vec![];
159
2402
        while let Some(row) = rows.try_next().await? {
160
2272
            let _ = cursor.as_mut().try_next().await?;
161
2272
            let roles: HashMap<String, bool> = match serde_json::from_str(row.roles.as_str()) {
162
                Err(_) => HashMap::new(),
163
2272
                Ok(roles) => roles,
164
            };
165
2272
            list.push(User {
166
2272
                user_id: row.user_id,
167
2272
                account: row.account,
168
2272
                created_at: Utc.timestamp_nanos(row.created_at * 1000000),
169
2272
                modified_at: Utc.timestamp_nanos(row.modified_at * 1000000),
170
2272
                verified_at: match row.verified_at {
171
234
                    None => None,
172
2038
                    Some(value) => Some(Utc.timestamp_nanos(value * 1000000)),
173
                },
174
2272
                expired_at: match row.expired_at {
175
2154
                    None => None,
176
118
                    Some(value) => Some(Utc.timestamp_nanos(value * 1000000)),
177
                },
178
2272
                disabled_at: match row.disabled_at {
179
2150
                    None => None,
180
122
                    Some(value) => Some(Utc.timestamp_nanos(value * 1000000)),
181
                },
182
2272
                roles,
183
2272
                password: row.password,
184
2272
                salt: row.salt,
185
2272
                name: row.name,
186
2272
                info: serde_json::from_str(row.info.as_str())?,
187
            });
188
2272
            if let Some(limit) = opts_limit {
189
888
                if limit > 0 && cursor.offset() >= limit {
190
14
                    if let Some(cursor_max) = opts.cursor_max {
191
12
                        if (count + 1) >= cursor_max {
192
6
                            return Ok((list, Some(cursor)));
193
6
                        }
194
2
                    }
195
8
                    return Ok((list, None));
196
874
                }
197
1384
            }
198
2258
            if let Some(cursor_max) = opts.cursor_max {
199
2052
                count += 1;
200
2052
                if count >= cursor_max {
201
16
                    return Ok((list, Some(cursor)));
202
2036
                }
203
206
            }
204
        }
205
130
        Ok((list, None))
206
332
    }
207

            
208
898
    async fn get(&self, cond: &QueryCond) -> Result<Option<User>, Box<dyn StdError>> {
209
898
        let sql = build_where(SqlBuilder::select_from(TABLE_NAME).fields(FIELDS), &cond).sql()?;
210

            
211
898
        let result: Result<Schema, sqlx::Error> = sqlx::query_as(sql.as_str())
212
898
            .fetch_one(self.conn.as_ref())
213
898
            .await;
214

            
215
898
        let row = match result {
216
22
            Err(e) => match e {
217
22
                sqlx::Error::RowNotFound => return Ok(None),
218
                _ => return Err(Box::new(e)),
219
            },
220
876
            Ok(row) => row,
221
        };
222

            
223
876
        let roles: HashMap<String, bool> = match serde_json::from_str(row.roles.as_str()) {
224
            Err(_) => HashMap::new(),
225
876
            Ok(roles) => roles,
226
        };
227
        Ok(Some(User {
228
876
            user_id: row.user_id,
229
876
            account: row.account,
230
876
            created_at: Utc.timestamp_nanos(row.created_at * 1000000),
231
876
            modified_at: Utc.timestamp_nanos(row.modified_at * 1000000),
232
876
            verified_at: match row.verified_at {
233
24
                None => None,
234
852
                Some(value) => Some(Utc.timestamp_nanos(value * 1000000)),
235
            },
236
876
            expired_at: match row.expired_at {
237
856
                None => None,
238
20
                Some(value) => Some(Utc.timestamp_nanos(value * 1000000)),
239
            },
240
876
            disabled_at: match row.disabled_at {
241
854
                None => None,
242
22
                Some(value) => Some(Utc.timestamp_nanos(value * 1000000)),
243
            },
244
876
            roles,
245
876
            password: row.password,
246
876
            salt: row.salt,
247
876
            name: row.name,
248
876
            info: serde_json::from_str(row.info.as_str())?,
249
        }))
250
1796
    }
251

            
252
1098
    async fn add(&self, user: &User) -> Result<(), Box<dyn StdError>> {
253
1098
        let roles = match serde_json::to_string(&user.roles) {
254
            Err(_) => quote("{}"),
255
1098
            Ok(value) => quote(value.as_str()),
256
        };
257
1098
        let info = match serde_json::to_string(&user.info) {
258
            Err(_) => quote("{}"),
259
1098
            Ok(value) => quote(value.as_str()),
260
        };
261
1098
        let values = vec![
262
1098
            quote(user.user_id.as_str()),
263
1098
            quote(user.account.to_lowercase().as_str()),
264
1098
            user.created_at.timestamp_millis().to_string(),
265
1098
            user.modified_at.timestamp_millis().to_string(),
266
1098
            match user.verified_at {
267
62
                None => "NULL".to_string(),
268
1036
                Some(value) => value.timestamp_millis().to_string(),
269
            },
270
1098
            match user.expired_at {
271
1076
                None => "NULL".to_string(),
272
22
                Some(value) => value.timestamp_millis().to_string(),
273
            },
274
1098
            match user.disabled_at {
275
1082
                None => "NULL".to_string(),
276
16
                Some(value) => value.timestamp_millis().to_string(),
277
            },
278
1098
            roles,
279
1098
            quote(user.password.as_str()),
280
1098
            quote(user.salt.as_str()),
281
1098
            quote(user.name.as_str()),
282
1098
            info,
283
        ];
284
1098
        let sql = SqlBuilder::insert_into(TABLE_NAME)
285
1098
            .fields(FIELDS)
286
1098
            .values(&values)
287
1098
            .sql()?;
288
1098
        let _ = sqlx::query(sql.as_str())
289
1098
            .execute(self.conn.as_ref())
290
1098
            .await?;
291
1094
        Ok(())
292
2196
    }
293

            
294
10
    async fn del(&self, user_id: &str) -> Result<(), Box<dyn StdError>> {
295
10
        let sql = SqlBuilder::delete_from(TABLE_NAME)
296
10
            .and_where_eq("user_id", quote(user_id))
297
10
            .sql()?;
298
10
        let _ = sqlx::query(sql.as_str())
299
10
            .execute(self.conn.as_ref())
300
10
            .await?;
301
10
        Ok(())
302
20
    }
303

            
304
48
    async fn update(&self, user_id: &str, updates: &Updates) -> Result<(), Box<dyn StdError>> {
305
46
        let sql =
306
48
            match build_update_where(&mut SqlBuilder::update_table(TABLE_NAME), user_id, updates) {
307
2
                None => return Ok(()),
308
46
                Some(builder) => builder.sql()?,
309
            };
310
46
        let _ = sqlx::query(sql.as_str())
311
46
            .execute(self.conn.as_ref())
312
46
            .await?;
313
46
        Ok(())
314
96
    }
315
}
316

            
317
impl DbCursor {
318
    /// To create the cursor instance.
319
144
    pub fn new() -> Self {
320
144
        DbCursor { offset: 0 }
321
144
    }
322
}
323

            
324
#[async_trait]
325
impl Cursor for DbCursor {
326
2272
    async fn try_next(&mut self) -> Result<Option<User>, Box<dyn StdError>> {
327
2272
        self.offset += 1;
328
2272
        Ok(None)
329
4544
    }
330

            
331
1220
    fn offset(&self) -> u64 {
332
1220
        self.offset
333
1220
    }
334
}
335

            
336
/// Transforms query conditions to the SQL builder.
337
898
fn build_where<'a>(builder: &'a mut SqlBuilder, cond: &QueryCond<'a>) -> &'a mut SqlBuilder {
338
898
    if let Some(value) = cond.user_id {
339
598
        builder.and_where_eq("user_id", quote(value));
340
598
    }
341
898
    if let Some(value) = cond.account {
342
300
        builder.and_where_eq("account", quote(value.to_lowercase().as_str()));
343
598
    }
344
898
    builder
345
898
}
346

            
347
/// Transforms query conditions to the SQL builder.
348
202
fn build_list_where<'a>(
349
202
    builder: &'a mut SqlBuilder,
350
202
    cond: &ListQueryCond<'a>,
351
202
) -> &'a mut SqlBuilder {
352
202
    if let Some(value) = cond.user_id {
353
4
        builder.and_where_eq("user_id", quote(value));
354
198
    }
355
202
    if let Some(value) = cond.account {
356
22
        builder.and_where_eq("account", quote(value.to_lowercase().as_str()));
357
180
    }
358
202
    if let Some(value) = cond.account_contains {
359
70
        build_where_like(builder, "account", value.to_lowercase().as_str());
360
132
    }
361
202
    if let Some(value) = cond.verified_at {
362
12
        if value {
363
8
            builder.and_where_is_not_null("verified_at");
364
8
        } else {
365
4
            builder.and_where_is_null("verified_at");
366
4
        }
367
190
    }
368
202
    if let Some(value) = cond.disabled_at {
369
8
        if value {
370
4
            builder.and_where_is_not_null("disabled_at");
371
4
        } else {
372
4
            builder.and_where_is_null("disabled_at");
373
4
        }
374
194
    }
375
202
    if let Some(value) = cond.name_contains {
376
12
        build_where_like(builder, "name", value);
377
190
    }
378
202
    builder
379
202
}
380

            
381
/// Transforms model options to the SQL builder.
382
160
fn build_limit_offset<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
383
160
    if let Some(value) = opts.limit {
384
82
        if value > 0 {
385
80
            builder.limit(value);
386
80
        }
387
78
    }
388
160
    if let Some(value) = opts.offset {
389
160
        match opts.limit {
390
78
            None => builder.limit(-1).offset(value),
391
2
            Some(0) => builder.limit(-1).offset(value),
392
80
            _ => builder.offset(value),
393
        };
394
    }
395
160
    builder
396
160
}
397

            
398
/// Transforms model options to the SQL builder.
399
160
fn build_sort<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
400
160
    if let Some(sort_cond) = opts.sort.as_ref() {
401
146
        for cond in sort_cond.iter() {
402
146
            let key = match cond.key {
403
98
                SortKey::Account => "account",
404
12
                SortKey::CreatedAt => "created_at",
405
8
                SortKey::ModifiedAt => "modified_at",
406
8
                SortKey::VerifiedAt => "verified_at",
407
4
                SortKey::ExpiredAt => "expired_at",
408
4
                SortKey::DisabledAt => "disabled_at",
409
12
                SortKey::Name => "name",
410
            };
411
146
            builder.order_by(key, !cond.asc);
412
        }
413
24
    }
414
160
    builder
415
160
}
416

            
417
/// Transforms query conditions and the model object to the SQL builder.
418
48
fn build_update_where<'a>(
419
48
    builder: &'a mut SqlBuilder,
420
48
    user_id: &str,
421
48
    updates: &Updates,
422
48
) -> Option<&'a mut SqlBuilder> {
423
48
    let mut count = 0;
424
48
    if let Some(value) = updates.modified_at.as_ref() {
425
46
        builder.set("modified_at", value.timestamp_millis());
426
46
        count += 1;
427
46
    }
428
48
    if let Some(value) = updates.verified_at.as_ref() {
429
8
        builder.set("verified_at", value.timestamp_millis());
430
8
        count += 1;
431
40
    }
432
48
    if let Some(value) = updates.expired_at.as_ref() {
433
8
        match value {
434
6
            None => {
435
6
                builder.set("expired_at", "NULL");
436
6
            }
437
2
            Some(value) => {
438
2
                builder.set("expired_at", value.timestamp_millis());
439
2
            }
440
        }
441
8
        count += 1;
442
40
    }
443
48
    if let Some(value) = updates.disabled_at.as_ref() {
444
16
        match value {
445
8
            None => {
446
8
                builder.set("disabled_at", "NULL");
447
8
            }
448
8
            Some(value) => {
449
8
                builder.set("disabled_at", value.timestamp_millis());
450
8
            }
451
        }
452
16
        count += 1;
453
32
    }
454
48
    if let Some(value) = updates.roles {
455
12
        builder.set(
456
12
            "roles",
457
12
            match serde_json::to_string(value) {
458
                Err(_) => quote("{}"),
459
12
                Ok(value) => quote(value.as_str()),
460
            },
461
        );
462
12
        count += 1;
463
36
    }
464
48
    if let Some(value) = updates.password.as_ref() {
465
14
        builder.set("password", quote(value));
466
14
        count += 1;
467
34
    }
468
48
    if let Some(value) = updates.salt.as_ref() {
469
14
        builder.set("salt", quote(value));
470
14
        count += 1;
471
34
    }
472
48
    if let Some(value) = updates.name {
473
32
        builder.set("name", quote(value));
474
32
        count += 1;
475
32
    }
476
48
    if let Some(value) = updates.info {
477
32
        match serde_json::to_string(value) {
478
            Err(_) => {
479
                builder.set("info", quote("{}"));
480
            }
481
32
            Ok(value) => {
482
32
                builder.set("info", quote(value));
483
32
            }
484
        }
485
32
        count += 1;
486
16
    }
487
48
    if count == 0 {
488
2
        return None;
489
46
    }
490
46

            
491
46
    builder.and_where_eq("user_id", quote(user_id));
492
46
    Some(builder)
493
48
}