1
use std::{error::Error as StdError, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeZone, Utc};
5
use sql_builder::{quote, SqlBuilder};
6
use sqlx::SqlitePool;
7

            
8
use super::super::refresh_token::{QueryCond, RefreshToken, RefreshTokenModel};
9

            
10
/// Model instance.
11
pub struct Model {
12
    /// The associated database connection.
13
    conn: Arc<SqlitePool>,
14
}
15

            
16
/// SQLite schema.
17
#[derive(sqlx::FromRow)]
18
struct Schema {
19
    refresh_token: String,
20
    /// i64 as time tick from Epoch in milliseconds.
21
    expires_at: i64,
22
    scope: Option<String>,
23
    client_id: String,
24
    redirect_uri: String,
25
    user_id: String,
26
}
27

            
28
/// Name of the SQLite table that stores refresh tokens.
const TABLE_NAME: &str = "refresh_token";
/// Column list used for both `SELECT` and `INSERT`; `add` builds its value list in this
/// exact order.
const FIELDS: &[&str] = &[
    "refresh_token",
    "expires_at",
    "scope",
    "client_id",
    "redirect_uri",
    "user_id",
];
/// DDL that creates the `refresh_token` table only if it does not exist yet.
///
/// `expires_at` holds milliseconds since the Unix epoch.
/// NOTE(review): the `UNIQUE` on `refresh_token` duplicates its `PRIMARY KEY`; it is left
/// unchanged so the DDL text stays identical for existing deployments.
const TABLE_INIT_SQL: &str = "\
    CREATE TABLE IF NOT EXISTS refresh_token (\
    refresh_token TEXT NOT NULL UNIQUE,\
    expires_at INTEGER NOT NULL,\
    scope TEXT,\
    client_id TEXT NOT NULL,\
    redirect_uri TEXT NOT NULL,\
    user_id TEXT NOT NULL,\
    PRIMARY KEY (refresh_token))";
46

            
47
impl Model {
48
    /// To create the model instance with a database connection.
49
22
    pub async fn new(conn: Arc<SqlitePool>) -> Result<Self, Box<dyn StdError>> {
50
22
        let model = Model { conn };
51
22
        model.init().await?;
52
22
        Ok(model)
53
22
    }
54
}
55

            
56
#[async_trait]
57
impl RefreshTokenModel for Model {
58
40
    async fn init(&self) -> Result<(), Box<dyn StdError>> {
59
40
        let _ = sqlx::query(TABLE_INIT_SQL)
60
40
            .execute(self.conn.as_ref())
61
40
            .await?;
62
40
        Ok(())
63
80
    }
64

            
65
66
    async fn get(&self, refresh_token: &str) -> Result<Option<RefreshToken>, Box<dyn StdError>> {
66
66
        let cond = QueryCond {
67
66
            refresh_token: Some(refresh_token),
68
66
            ..Default::default()
69
66
        };
70
66
        let sql = get_query_sql(SqlBuilder::select_from(TABLE_NAME).fields(FIELDS), &cond).sql()?;
71

            
72
66
        let result: Result<Schema, sqlx::Error> = sqlx::query_as(sql.as_str())
73
66
            .fetch_one(self.conn.as_ref())
74
66
            .await;
75

            
76
66
        let row = match result {
77
18
            Err(e) => match e {
78
18
                sqlx::Error::RowNotFound => return Ok(None),
79
                _ => return Err(Box::new(e)),
80
            },
81
48
            Ok(row) => row,
82
48
        };
83
48
        Ok(Some(RefreshToken {
84
48
            refresh_token: row.refresh_token,
85
48
            expires_at: Utc.timestamp_nanos(row.expires_at * 1000000),
86
48
            scope: row.scope,
87
48
            client_id: row.client_id,
88
48
            user_id: row.user_id,
89
48
            redirect_uri: row.redirect_uri,
90
48
        }))
91
132
    }
92

            
93
298
    async fn add(&self, token: &RefreshToken) -> Result<(), Box<dyn StdError>> {
94
298
        let scope = match token.scope.as_deref() {
95
6
            None => "NULL".to_string(),
96
292
            Some(scope) => quote(scope),
97
        };
98
298
        let values = vec![
99
298
            quote(token.refresh_token.as_str()),
100
298
            token.expires_at.timestamp_millis().to_string(),
101
298
            scope,
102
298
            quote(token.client_id.as_str()),
103
298
            quote(token.redirect_uri.as_str()),
104
298
            quote(token.user_id.as_str()),
105
298
        ];
106
298
        let sql = SqlBuilder::insert_into(TABLE_NAME)
107
298
            .fields(FIELDS)
108
298
            .values(&values)
109
298
            .sql()?;
110
298
        let _ = sqlx::query(sql.as_str())
111
298
            .execute(self.conn.as_ref())
112
298
            .await?;
113
296
        Ok(())
114
596
    }
115

            
116
36
    async fn del(&self, cond: &QueryCond) -> Result<(), Box<dyn StdError>> {
117
36
        let sql = get_query_sql(&mut SqlBuilder::delete_from(TABLE_NAME), cond).sql()?;
118
36
        let _ = sqlx::query(sql.as_str())
119
36
            .execute(self.conn.as_ref())
120
36
            .await?;
121
36
        Ok(())
122
72
    }
123
}
124

            
125
/// Transforms query conditions to the SQL builder.
126
102
fn get_query_sql<'a>(builder: &'a mut SqlBuilder, cond: &QueryCond<'a>) -> &'a mut SqlBuilder {
127
102
    if let Some(value) = cond.refresh_token {
128
84
        builder.and_where_eq("refresh_token", quote(value));
129
84
    }
130
102
    if let Some(value) = cond.client_id {
131
6
        builder.and_where_eq("client_id", quote(value));
132
96
    }
133
102
    if let Some(value) = cond.user_id {
134
14
        builder.and_where_eq("user_id", quote(value));
135
88
    }
136
102
    builder
137
102
}