1
use std::{error::Error as StdError, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeZone, Utc};
5
use sql_builder::{quote, SqlBuilder};
6
use sqlx::SqlitePool;
7

            
8
use super::super::authorization_code::{AuthorizationCode, AuthorizationCodeModel, QueryCond};
9

            
10
/// Model instance.
11
pub struct Model {
12
    /// The associated database connection.
13
    conn: Arc<SqlitePool>,
14
}
15

            
16
/// SQLite schema.
17
#[derive(sqlx::FromRow)]
18
struct Schema {
19
    code: String,
20
    /// i64 as time tick from Epoch in milliseconds.
21
    expires_at: i64,
22
    redirect_uri: String,
23
    scope: Option<String>,
24
    client_id: String,
25
    user_id: String,
26
}
27

            
28
/// Name of the table that stores authorization codes.
const TABLE_NAME: &str = "authorization_code";
/// Column names, in the order used by both `SELECT` and `INSERT` statements.
const FIELDS: &[&str] = &[
    "code",
    "expires_at",
    "redirect_uri",
    "scope",
    "client_id",
    "user_id",
];
/// Statement that creates the table when it does not exist yet.
///
/// Each line ends with `\` so the literal is a single line of SQL: the line
/// break and the indentation of the following line are not part of the string.
const TABLE_INIT_SQL: &str = "\
    CREATE TABLE IF NOT EXISTS authorization_code (\
    code TEXT NOT NULL UNIQUE,\
    expires_at INTEGER NOT NULL,\
    redirect_uri TEXT NOT NULL,\
    scope TEXT,\
    client_id TEXT NOT NULL,\
    user_id TEXT NOT NULL,\
    PRIMARY KEY (code))";
46

            
47
impl Model {
48
    /// To create the model instance with a database connection.
49
11
    pub async fn new(conn: Arc<SqlitePool>) -> Result<Self, Box<dyn StdError>> {
50
11
        let model = Model { conn };
51
11
        model.init().await?;
52
11
        Ok(model)
53
11
    }
54
}
55

            
56
#[async_trait]
57
impl AuthorizationCodeModel for Model {
58
20
    async fn init(&self) -> Result<(), Box<dyn StdError>> {
59
20
        let _ = sqlx::query(TABLE_INIT_SQL)
60
20
            .execute(self.conn.as_ref())
61
20
            .await?;
62
20
        Ok(())
63
40
    }
64

            
65
145
    async fn get(&self, code: &str) -> Result<Option<AuthorizationCode>, Box<dyn StdError>> {
66
145
        let cond = QueryCond {
67
145
            code: Some(code),
68
145
            ..Default::default()
69
145
        };
70
145
        let sql = get_query_sql(SqlBuilder::select_from(TABLE_NAME).fields(FIELDS), &cond).sql()?;
71

            
72
145
        let result: Result<Schema, sqlx::Error> = sqlx::query_as(sql.as_str())
73
145
            .fetch_one(self.conn.as_ref())
74
145
            .await;
75

            
76
145
        let row = match result {
77
9
            Err(e) => match e {
78
9
                sqlx::Error::RowNotFound => return Ok(None),
79
                _ => return Err(Box::new(e)),
80
            },
81
136
            Ok(row) => row,
82
136
        };
83
136
        Ok(Some(AuthorizationCode {
84
136
            code: row.code,
85
136
            expires_at: Utc.timestamp_nanos(row.expires_at * 1000000),
86
136
            redirect_uri: row.redirect_uri,
87
136
            scope: row.scope,
88
136
            client_id: row.client_id,
89
136
            user_id: row.user_id,
90
136
        }))
91
290
    }
92

            
93
155
    async fn add(&self, code: &AuthorizationCode) -> Result<(), Box<dyn StdError>> {
94
155
        let scope = match code.scope.as_deref() {
95
135
            None => "NULL".to_string(),
96
20
            Some(value) => quote(value),
97
        };
98
155
        let values = vec![
99
155
            quote(code.code.as_str()),
100
155
            code.expires_at.timestamp_millis().to_string(),
101
155
            quote(code.redirect_uri.as_str()),
102
155
            scope,
103
155
            quote(code.client_id.as_str()),
104
155
            quote(code.user_id.as_str()),
105
155
        ];
106
155
        let sql = SqlBuilder::insert_into(TABLE_NAME)
107
155
            .fields(FIELDS)
108
155
            .values(&values)
109
155
            .sql()?;
110
155
        let _ = sqlx::query(sql.as_str())
111
155
            .execute(self.conn.as_ref())
112
155
            .await?;
113
154
        Ok(())
114
310
    }
115

            
116
137
    async fn del(&self, cond: &QueryCond) -> Result<(), Box<dyn StdError>> {
117
137
        let sql = get_query_sql(&mut SqlBuilder::delete_from(TABLE_NAME), cond).sql()?;
118
137
        let _ = sqlx::query(sql.as_str())
119
137
            .execute(self.conn.as_ref())
120
137
            .await?;
121
137
        Ok(())
122
274
    }
123
}
124

            
125
/// Transforms query conditions to the SQL builder.
126
282
fn get_query_sql<'a>(builder: &'a mut SqlBuilder, cond: &QueryCond<'a>) -> &'a mut SqlBuilder {
127
282
    if let Some(value) = cond.code {
128
273
        builder.and_where_eq("code", quote(value));
129
273
    }
130
282
    if let Some(value) = cond.client_id {
131
3
        builder.and_where_eq("client_id", quote(value));
132
279
    }
133
282
    if let Some(value) = cond.user_id {
134
7
        builder.and_where_eq("user_id", quote(value));
135
275
    }
136
282
    builder
137
282
}