1
use std::{error::Error as StdError, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeZone, Utc};
5
use sql_builder::{quote, SqlBuilder};
6
use sqlx::SqlitePool;
7

            
8
use super::super::authorization_code::{AuthorizationCode, AuthorizationCodeModel, QueryCond};
9

            
10
/// Model instance.
11
pub struct Model {
12
    /// The associated database connection.
13
    conn: Arc<SqlitePool>,
14
}
15

            
16
/// SQLite schema.
17
#[derive(sqlx::FromRow)]
18
struct Schema {
19
    code: String,
20
    /// i64 as time tick from Epoch in milliseconds.
21
    expires_at: i64,
22
    redirect_uri: String,
23
    scope: Option<String>,
24
    client_id: String,
25
    user_id: String,
26
}
27

            
28
/// Name of the SQLite table that holds authorization codes.
const TABLE_NAME: &str = "authorization_code";

/// Column list shared by the `SELECT` and `INSERT` statements.
///
/// The order here must match the order in which `add` builds its value list.
const FIELDS: &[&str] = &[
    "code",
    "expires_at",
    "redirect_uri",
    "scope",
    "client_id",
    "user_id",
];

/// DDL that creates the table only if it is missing, so it is safe to run on every start-up.
const TABLE_INIT_SQL: &str = concat!(
    "CREATE TABLE IF NOT EXISTS authorization_code (",
    "code TEXT NOT NULL UNIQUE,",
    "expires_at INTEGER NOT NULL,",
    "redirect_uri TEXT NOT NULL,",
    "scope TEXT,",
    "client_id TEXT NOT NULL,",
    "user_id TEXT NOT NULL,",
    "PRIMARY KEY (code))"
);
46

            
47
impl Model {
48
    /// To create the model instance with a database connection.
49
22
    pub async fn new(conn: Arc<SqlitePool>) -> Result<Self, Box<dyn StdError>> {
50
22
        let model = Model { conn };
51
22
        model.init().await?;
52
22
        Ok(model)
53
22
    }
54
}
55

            
56
#[async_trait]
57
impl AuthorizationCodeModel for Model {
58
40
    async fn init(&self) -> Result<(), Box<dyn StdError>> {
59
40
        let _ = sqlx::query(TABLE_INIT_SQL)
60
40
            .execute(self.conn.as_ref())
61
40
            .await?;
62
40
        Ok(())
63
80
    }
64

            
65
290
    async fn get(&self, code: &str) -> Result<Option<AuthorizationCode>, Box<dyn StdError>> {
66
290
        let cond = QueryCond {
67
290
            code: Some(code),
68
290
            ..Default::default()
69
290
        };
70
290
        let sql = get_query_sql(SqlBuilder::select_from(TABLE_NAME).fields(FIELDS), &cond).sql()?;
71

            
72
290
        let result: Result<Schema, sqlx::Error> = sqlx::query_as(sql.as_str())
73
290
            .fetch_one(self.conn.as_ref())
74
290
            .await;
75

            
76
290
        let row = match result {
77
18
            Err(e) => match e {
78
18
                sqlx::Error::RowNotFound => return Ok(None),
79
                _ => return Err(Box::new(e)),
80
            },
81
272
            Ok(row) => row,
82
272
        };
83
272
        Ok(Some(AuthorizationCode {
84
272
            code: row.code,
85
272
            expires_at: Utc.timestamp_nanos(row.expires_at * 1000000),
86
272
            redirect_uri: row.redirect_uri,
87
272
            scope: row.scope,
88
272
            client_id: row.client_id,
89
272
            user_id: row.user_id,
90
272
        }))
91
580
    }
92

            
93
310
    async fn add(&self, code: &AuthorizationCode) -> Result<(), Box<dyn StdError>> {
94
310
        let scope = match code.scope.as_deref() {
95
270
            None => "NULL".to_string(),
96
40
            Some(value) => quote(value),
97
        };
98
310
        let values = vec![
99
310
            quote(code.code.as_str()),
100
310
            code.expires_at.timestamp_millis().to_string(),
101
310
            quote(code.redirect_uri.as_str()),
102
310
            scope,
103
310
            quote(code.client_id.as_str()),
104
310
            quote(code.user_id.as_str()),
105
310
        ];
106
310
        let sql = SqlBuilder::insert_into(TABLE_NAME)
107
310
            .fields(FIELDS)
108
310
            .values(&values)
109
310
            .sql()?;
110
310
        let _ = sqlx::query(sql.as_str())
111
310
            .execute(self.conn.as_ref())
112
310
            .await?;
113
308
        Ok(())
114
620
    }
115

            
116
274
    async fn del(&self, cond: &QueryCond) -> Result<(), Box<dyn StdError>> {
117
274
        let sql = get_query_sql(&mut SqlBuilder::delete_from(TABLE_NAME), cond).sql()?;
118
274
        let _ = sqlx::query(sql.as_str())
119
274
            .execute(self.conn.as_ref())
120
274
            .await?;
121
274
        Ok(())
122
548
    }
123
}
124

            
125
/// Transforms query conditions to the SQL builder.
126
564
fn get_query_sql<'a>(builder: &'a mut SqlBuilder, cond: &QueryCond<'a>) -> &'a mut SqlBuilder {
127
564
    if let Some(value) = cond.code {
128
546
        builder.and_where_eq("code", quote(value));
129
546
    }
130
564
    if let Some(value) = cond.client_id {
131
6
        builder.and_where_eq("client_id", quote(value));
132
558
    }
133
564
    if let Some(value) = cond.user_id {
134
14
        builder.and_where_eq("user_id", quote(value));
135
550
    }
136
564
    builder
137
564
}