1
use std::{error::Error as StdError, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeZone, Utc};
5
use futures::TryStreamExt;
6
use sql_builder::{quote, SqlBuilder};
7
use sqlx::SqlitePool;
8

            
9
use super::{
10
    super::client::{
11
        Client, ClientModel, Cursor, ListOptions, ListQueryCond, QueryCond, SortKey,
12
        UpdateQueryCond, Updates,
13
    },
14
    build_where_like,
15
};
16

            
17
/// Model instance.
18
pub struct Model {
19
    /// The associated database connection.
20
    conn: Arc<SqlitePool>,
21
}
22

            
23
/// Cursor instance.
///
/// The SQLite implementation only keeps the progress offset, i.e. how many rows have already
/// been consumed through this cursor. The list options are supplied again on every `list()`
/// call and are combined with this offset there.
pub struct DbCursor {
    /// Number of rows that were already handed out through this cursor.
    offset: u64,
}
29

            
30
/// SQLite schema.
31
#[derive(sqlx::FromRow)]
32
struct Schema {
33
    client_id: String,
34
    /// i64 as time tick from Epoch in milliseconds.
35
    created_at: i64,
36
    /// i64 as time tick from Epoch in milliseconds.
37
    modified_at: i64,
38
    client_secret: Option<String>,
39
    redirect_uris: String,
40
    /// Space-separated value such as `scope1 scope2`.
41
    scopes: String,
42
    user_id: String,
43
    name: String,
44
    image_url: Option<String>,
45
}
46

            
47
/// Use "COUNT(*)" instead of "COUNT(fields...)" to simplify the implementation.
48
#[derive(sqlx::FromRow)]
49
struct CountSchema {
50
    #[sqlx(rename = "COUNT(*)")]
51
    count: i64,
52
}
53

            
54
/// Name of the table that stores the clients.
const TABLE_NAME: &str = "client";
/// All columns of the table, in the order used for `SELECT` and `INSERT` statements.
const FIELDS: &[&str] = &[
    "client_id",
    "created_at",
    "modified_at",
    "client_secret",
    "redirect_uris",
    "scopes",
    "user_id",
    "name",
    "image_url",
];
/// Statement that creates the table when it does not exist yet.
///
/// The trailing backslashes join the lines without inserting any whitespace, so the statement
/// is a single line of SQL.
const TABLE_INIT_SQL: &str = "\
    CREATE TABLE IF NOT EXISTS client (\
    client_id TEXT NOT NULL UNIQUE,\
    created_at INTEGER NOT NULL,\
    modified_at INTEGER NOT NULL,\
    client_secret TEXT,\
    redirect_uris TEXT NOT NULL,\
    scopes TEXT NOT NULL,\
    user_id TEXT NOT NULL,\
    name TEXT NOT NULL,\
    image_url TEXT,\
    PRIMARY KEY (client_id))";
78

            
79
impl Model {
80
    /// To create the model instance with a database connection.
81
22
    pub async fn new(conn: Arc<SqlitePool>) -> Result<Self, Box<dyn StdError>> {
82
22
        let model = Model { conn };
83
22
        model.init().await?;
84
22
        Ok(model)
85
22
    }
86
}
87

            
88
#[async_trait]
89
impl ClientModel for Model {
90
40
    async fn init(&self) -> Result<(), Box<dyn StdError>> {
91
40
        let _ = sqlx::query(TABLE_INIT_SQL)
92
40
            .execute(self.conn.as_ref())
93
40
            .await?;
94
40
        Ok(())
95
80
    }
96

            
97
20
    async fn count(&self, cond: &ListQueryCond) -> Result<u64, Box<dyn StdError>> {
98
20
        let sql = build_list_where(SqlBuilder::select_from(TABLE_NAME).count("*"), &cond).sql()?;
99

            
100
20
        let result: Result<CountSchema, sqlx::Error> = sqlx::query_as(sql.as_str())
101
20
            .fetch_one(self.conn.as_ref())
102
20
            .await;
103

            
104
20
        let row = match result {
105
            Err(e) => {
106
                return Err(Box::new(e));
107
            }
108
20
            Ok(row) => row,
109
20
        };
110
20
        Ok(row.count as u64)
111
40
    }
112

            
113
    async fn list(
114
        &self,
115
        opts: &ListOptions,
116
        cursor: Option<Box<dyn Cursor>>,
117
114
    ) -> Result<(Vec<Client>, Option<Box<dyn Cursor>>), Box<dyn StdError>> {
118
114
        let mut cursor = match cursor {
119
92
            None => Box::new(DbCursor::new()),
120
22
            Some(cursor) => cursor,
121
        };
122

            
123
114
        let mut opts = ListOptions { ..*opts };
124
114
        if let Some(offset) = opts.offset {
125
28
            opts.offset = Some(offset + cursor.offset());
126
86
        } else {
127
86
            opts.offset = Some(cursor.offset());
128
86
        }
129
114
        let opts_limit = opts.limit;
130
114
        if let Some(limit) = opts_limit {
131
64
            if limit > 0 {
132
62
                if cursor.offset() >= limit {
133
6
                    return Ok((vec![], None));
134
56
                }
135
56
                opts.limit = Some(limit - cursor.offset());
136
2
            }
137
50
        }
138
108
        let mut builder = SqlBuilder::select_from(TABLE_NAME);
139
108
        build_limit_offset(&mut builder, &opts);
140
108
        build_sort(&mut builder, &opts);
141
108
        let sql = build_list_where(&mut builder, opts.cond).sql()?;
142

            
143
108
        let mut rows = sqlx::query_as::<_, Schema>(sql.as_str()).fetch(self.conn.as_ref());
144
108

            
145
108
        let mut count: u64 = 0;
146
108
        let mut list = vec![];
147
1952
        while let Some(row) = rows.try_next().await? {
148
1878
            let _ = cursor.as_mut().try_next().await?;
149
1878
            let redirect_uris = row
150
1878
                .redirect_uris
151
1878
                .split(" ")
152
1878
                .filter_map(|x| {
153
1878
                    if x.len() > 0 {
154
1734
                        Some(x.to_string())
155
                    } else {
156
144
                        None
157
                    }
158
1878
                })
159
1878
                .collect();
160
1878
            let scopes = row
161
1878
                .scopes
162
1878
                .split(" ")
163
1878
                .filter_map(|x| {
164
1878
                    if x.len() > 0 {
165
                        Some(x.to_string())
166
                    } else {
167
1878
                        None
168
                    }
169
1878
                })
170
1878
                .collect();
171
1878
            list.push(Client {
172
1878
                client_id: row.client_id,
173
1878
                created_at: Utc.timestamp_nanos(row.created_at * 1000000),
174
1878
                modified_at: Utc.timestamp_nanos(row.modified_at * 1000000),
175
1878
                client_secret: row.client_secret,
176
1878
                redirect_uris,
177
1878
                scopes,
178
1878
                user_id: row.user_id,
179
1878
                name: row.name,
180
1878
                image_url: row.image_url,
181
1878
            });
182
1878
            if let Some(limit) = opts_limit {
183
978
                if limit > 0 && cursor.offset() >= limit {
184
18
                    if let Some(cursor_max) = opts.cursor_max {
185
16
                        if (count + 1) >= cursor_max {
186
6
                            return Ok((list, Some(cursor)));
187
10
                        }
188
2
                    }
189
12
                    return Ok((list, None));
190
960
                }
191
900
            }
192
1860
            if let Some(cursor_max) = opts.cursor_max {
193
1744
                count += 1;
194
1744
                if count >= cursor_max {
195
16
                    return Ok((list, Some(cursor)));
196
1728
                }
197
116
            }
198
        }
199
74
        Ok((list, None))
200
228
    }
201

            
202
1726
    async fn get(&self, cond: &QueryCond) -> Result<Option<Client>, Box<dyn StdError>> {
203
1726
        let sql = build_where(SqlBuilder::select_from(TABLE_NAME).fields(FIELDS), &cond).sql()?;
204

            
205
1726
        let result: Result<Schema, sqlx::Error> = sqlx::query_as(sql.as_str())
206
1726
            .fetch_one(self.conn.as_ref())
207
1726
            .await;
208

            
209
1726
        let row = match result {
210
34
            Err(e) => match e {
211
34
                sqlx::Error::RowNotFound => return Ok(None),
212
                _ => return Err(Box::new(e)),
213
            },
214
1692
            Ok(row) => row,
215
1692
        };
216
1692

            
217
1692
        let redirect_uris = row
218
1692
            .redirect_uris
219
1692
            .split(" ")
220
1722
            .filter_map(|x| {
221
1722
                if x.len() > 0 {
222
1682
                    Some(x.to_string())
223
                } else {
224
40
                    None
225
                }
226
1722
            })
227
1692
            .collect();
228
1692
        let scopes = row
229
1692
            .scopes
230
1692
            .split(" ")
231
1724
            .filter_map(|x| {
232
1724
                if x.len() > 0 {
233
256
                    Some(x.to_string())
234
                } else {
235
1468
                    None
236
                }
237
1724
            })
238
1692
            .collect();
239
1692
        Ok(Some(Client {
240
1692
            client_id: row.client_id,
241
1692
            created_at: Utc.timestamp_nanos(row.created_at * 1000000),
242
1692
            modified_at: Utc.timestamp_nanos(row.modified_at * 1000000),
243
1692
            client_secret: row.client_secret,
244
1692
            redirect_uris,
245
1692
            scopes,
246
1692
            user_id: row.user_id,
247
1692
            name: row.name,
248
1692
            image_url: row.image_url,
249
1692
        }))
250
3452
    }
251

            
252
1068
    async fn add(&self, client: &Client) -> Result<(), Box<dyn StdError>> {
253
1068
        let client_secret = match client.client_secret.as_deref() {
254
1044
            None => "NULL".to_string(),
255
24
            Some(value) => quote(value),
256
        };
257
1068
        let image_url = match client.image_url.as_deref() {
258
1060
            None => "NULL".to_string(),
259
8
            Some(value) => quote(value),
260
        };
261
1068
        let values = vec![
262
1068
            quote(client.client_id.as_str()),
263
1068
            client.created_at.timestamp_millis().to_string(),
264
1068
            client.modified_at.timestamp_millis().to_string(),
265
1068
            client_secret,
266
1068
            quote(client.redirect_uris.join(" ")),
267
1068
            quote(client.scopes.join(" ")),
268
1068
            quote(client.user_id.as_str()),
269
1068
            quote(client.name.as_str()),
270
1068
            image_url,
271
1068
        ];
272
1068
        let sql = SqlBuilder::insert_into(TABLE_NAME)
273
1068
            .fields(FIELDS)
274
1068
            .values(&values)
275
1068
            .sql()?;
276
1068
        let _ = sqlx::query(sql.as_str())
277
1068
            .execute(self.conn.as_ref())
278
1068
            .await?;
279
1066
        Ok(())
280
2136
    }
281

            
282
20
    async fn del(&self, cond: &QueryCond) -> Result<(), Box<dyn StdError>> {
283
20
        let sql = build_where(&mut SqlBuilder::delete_from(TABLE_NAME), cond).sql()?;
284
20
        let _ = sqlx::query(sql.as_str())
285
20
            .execute(self.conn.as_ref())
286
20
            .await?;
287
20
        Ok(())
288
40
    }
289

            
290
    async fn update(
291
        &self,
292
        cond: &UpdateQueryCond,
293
        updates: &Updates,
294
28
    ) -> Result<(), Box<dyn StdError>> {
295
28
        let sql = match build_update_where(&mut SqlBuilder::update_table(TABLE_NAME), cond, updates)
296
        {
297
2
            None => return Ok(()),
298
26
            Some(builder) => builder.sql()?,
299
        };
300
26
        let _ = sqlx::query(sql.as_str())
301
26
            .execute(self.conn.as_ref())
302
26
            .await?;
303
26
        Ok(())
304
56
    }
305
}
306

            
307
impl DbCursor {
308
    /// To create the cursor instance.
309
92
    pub fn new() -> Self {
310
92
        DbCursor { offset: 0 }
311
92
    }
312
}
313

            
314
#[async_trait]
315
impl Cursor for DbCursor {
316
1878
    async fn try_next(&mut self) -> Result<Option<Client>, Box<dyn StdError>> {
317
1878
        self.offset += 1;
318
1878
        Ok(None)
319
3756
    }
320

            
321
1210
    fn offset(&self) -> u64 {
322
1210
        self.offset
323
1210
    }
324
}
325

            
326
/// Transforms query conditions to the SQL builder.
327
1746
fn build_where<'a>(builder: &'a mut SqlBuilder, cond: &QueryCond<'a>) -> &'a mut SqlBuilder {
328
1746
    if let Some(value) = cond.user_id {
329
24
        builder.and_where_eq("user_id", quote(value));
330
1722
    }
331
1746
    if let Some(value) = cond.client_id {
332
1742
        builder.and_where_eq("client_id", quote(value));
333
1742
    }
334
1746
    builder
335
1746
}
336

            
337
/// Transforms query conditions to the SQL builder.
338
128
fn build_list_where<'a>(
339
128
    builder: &'a mut SqlBuilder,
340
128
    cond: &ListQueryCond<'a>,
341
128
) -> &'a mut SqlBuilder {
342
128
    if let Some(value) = cond.user_id {
343
60
        builder.and_where_eq("user_id", quote(value));
344
68
    }
345
128
    if let Some(value) = cond.client_id {
346
8
        builder.and_where_eq("client_id", quote(value));
347
120
    }
348
128
    if let Some(value) = cond.name_contains {
349
16
        build_where_like(builder, "name", value);
350
112
    }
351
128
    builder
352
128
}
353

            
354
/// Transforms model options to the SQL builder.
355
108
fn build_limit_offset<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
356
108
    if let Some(value) = opts.limit {
357
58
        if value > 0 {
358
56
            builder.limit(value);
359
56
        }
360
50
    }
361
108
    if let Some(value) = opts.offset {
362
108
        match opts.limit {
363
50
            None => builder.limit(-1).offset(value),
364
2
            Some(0) => builder.limit(-1).offset(value),
365
56
            _ => builder.offset(value),
366
        };
367
    }
368
108
    builder
369
108
}
370

            
371
/// Transforms model options to the SQL builder.
372
108
fn build_sort<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
373
108
    if let Some(sort_cond) = opts.sort.as_ref() {
374
104
        for cond in sort_cond.iter() {
375
104
            let key = match cond.key {
376
22
                SortKey::CreatedAt => "created_at",
377
8
                SortKey::ModifiedAt => "modified_at",
378
74
                SortKey::Name => "name",
379
            };
380
104
            builder.order_by(key, !cond.asc);
381
        }
382
16
    }
383
108
    builder
384
108
}
385

            
386
/// Transforms query conditions and the model object to the SQL builder.
387
28
fn build_update_where<'a>(
388
28
    builder: &'a mut SqlBuilder,
389
28
    cond: &UpdateQueryCond<'a>,
390
28
    updates: &Updates,
391
28
) -> Option<&'a mut SqlBuilder> {
392
28
    let mut count = 0;
393
28
    if let Some(value) = updates.modified_at.as_ref() {
394
26
        builder.set("modified_at", value.timestamp_millis());
395
26
        count += 1;
396
26
    }
397
28
    if let Some(value) = updates.client_secret.as_ref() {
398
6
        match value {
399
2
            None => {
400
2
                builder.set("client_secret", "NULL");
401
2
            }
402
4
            Some(value) => {
403
4
                builder.set("client_secret", quote(value));
404
4
            }
405
        }
406
6
        count += 1;
407
22
    }
408
28
    if let Some(value) = updates.redirect_uris.as_ref() {
409
16
        builder.set("redirect_uris", quote(value.join(" ")));
410
16
        count += 1;
411
16
    }
412
28
    if let Some(value) = updates.scopes.as_ref() {
413
16
        builder.set("scopes", quote(value.join(" ")));
414
16
        count += 1;
415
16
    }
416
28
    if let Some(value) = updates.name.as_ref() {
417
10
        builder.set("name", quote(value));
418
10
        count += 1;
419
18
    }
420
28
    if let Some(value) = updates.image_url.as_ref() {
421
16
        match value {
422
8
            None => {
423
8
                builder.set("image_url", "NULL");
424
8
            }
425
8
            Some(value) => {
426
8
                builder.set("image_url", quote(value));
427
8
            }
428
        }
429
16
        count += 1;
430
12
    }
431
28
    if count == 0 {
432
2
        return None;
433
26
    }
434
26

            
435
26
    builder.and_where_eq("user_id", quote(cond.user_id));
436
26
    builder.and_where_eq("client_id", quote(cond.client_id));
437
26
    Some(builder)
438
28
}