1
use std::{error::Error as StdError, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeZone, Utc};
5
use sql_builder::{quote, SqlBuilder};
6
use sqlx::SqlitePool;
7

            
8
use super::super::refresh_token::{QueryCond, RefreshToken, RefreshTokenModel};
9

            
10
/// Model instance.
11
pub struct Model {
12
    /// The associated database connection.
13
    conn: Arc<SqlitePool>,
14
}
15

            
16
/// SQLite schema.
17
#[derive(sqlx::FromRow)]
18
struct Schema {
19
    refresh_token: String,
20
    /// i64 as time tick from Epoch in milliseconds.
21
    expires_at: i64,
22
    scope: Option<String>,
23
    client_id: String,
24
    redirect_uri: String,
25
    user_id: String,
26
}
27

            
28
/// Name of the SQLite table that stores refresh tokens.
const TABLE_NAME: &str = "refresh_token";
/// Column list used for both SELECT and INSERT statements.
///
/// The `values` vector assembled in `add` must follow this exact order.
const FIELDS: &[&str] = &[
    "refresh_token",
    "expires_at",
    "scope",
    "client_id",
    "redirect_uri",
    "user_id",
];
/// DDL executed by `init`; `IF NOT EXISTS` makes it a no-op when the table is already present.
///
/// The trailing backslashes join the lines into a single-line statement.
const TABLE_INIT_SQL: &str = "\
    CREATE TABLE IF NOT EXISTS refresh_token (\
    refresh_token TEXT NOT NULL UNIQUE,\
    expires_at INTEGER NOT NULL,\
    scope TEXT,\
    client_id TEXT NOT NULL,\
    redirect_uri TEXT NOT NULL,\
    user_id TEXT NOT NULL,\
    PRIMARY KEY (refresh_token))";
46

            
47
impl Model {
48
    /// To create the model instance with a database connection.
49
11
    pub async fn new(conn: Arc<SqlitePool>) -> Result<Self, Box<dyn StdError>> {
50
11
        let model = Model { conn };
51
22
        model.init().await?;
52
11
        Ok(model)
53
11
    }
54
}
55

            
56
#[async_trait]
57
impl RefreshTokenModel for Model {
58
20
    async fn init(&self) -> Result<(), Box<dyn StdError>> {
59
20
        let _ = sqlx::query(TABLE_INIT_SQL)
60
20
            .execute(self.conn.as_ref())
61
40
            .await?;
62
20
        Ok(())
63
20
    }
64

            
65
33
    async fn get(&self, refresh_token: &str) -> Result<Option<RefreshToken>, Box<dyn StdError>> {
66
33
        let cond = QueryCond {
67
33
            refresh_token: Some(refresh_token),
68
33
            ..Default::default()
69
33
        };
70
33
        let sql = get_query_sql(SqlBuilder::select_from(TABLE_NAME).fields(FIELDS), &cond).sql()?;
71
33

            
72
33
        let result: Result<Schema, sqlx::Error> = sqlx::query_as(sql.as_str())
73
33
            .fetch_one(self.conn.as_ref())
74
66
            .await;
75
33

            
76
33
        let row = match result {
77
33
            Err(e) => match e {
78
33
                sqlx::Error::RowNotFound => return Ok(None),
79
33
                _ => return Err(Box::new(e)),
80
33
            },
81
33
            Ok(row) => row,
82
24
        };
83
24
        Ok(Some(RefreshToken {
84
24
            refresh_token: row.refresh_token,
85
24
            expires_at: Utc.timestamp_nanos(row.expires_at * 1000000),
86
24
            scope: row.scope,
87
24
            client_id: row.client_id,
88
24
            user_id: row.user_id,
89
24
            redirect_uri: row.redirect_uri,
90
24
        }))
91
33
    }
92

            
93
149
    async fn add(&self, token: &RefreshToken) -> Result<(), Box<dyn StdError>> {
94
149
        let scope = match token.scope.as_deref() {
95
149
            None => "NULL".to_string(),
96
149
            Some(scope) => quote(scope),
97
149
        };
98
149
        let values = vec![
99
149
            quote(token.refresh_token.as_str()),
100
149
            token.expires_at.timestamp_millis().to_string(),
101
149
            scope,
102
149
            quote(token.client_id.as_str()),
103
149
            quote(token.redirect_uri.as_str()),
104
149
            quote(token.user_id.as_str()),
105
149
        ];
106
149
        let sql = SqlBuilder::insert_into(TABLE_NAME)
107
149
            .fields(FIELDS)
108
149
            .values(&values)
109
149
            .sql()?;
110
149
        let _ = sqlx::query(sql.as_str())
111
149
            .execute(self.conn.as_ref())
112
298
            .await?;
113
149
        Ok(())
114
149
    }
115

            
116
18
    async fn del(&self, cond: &QueryCond) -> Result<(), Box<dyn StdError>> {
117
18
        let sql = get_query_sql(&mut SqlBuilder::delete_from(TABLE_NAME), cond).sql()?;
118
18
        let _ = sqlx::query(sql.as_str())
119
18
            .execute(self.conn.as_ref())
120
36
            .await?;
121
18
        Ok(())
122
18
    }
123
}
124

            
125
/// Transforms query conditions to the SQL builder.
126
51
fn get_query_sql<'a>(builder: &'a mut SqlBuilder, cond: &QueryCond<'a>) -> &'a mut SqlBuilder {
127
51
    if let Some(value) = cond.refresh_token {
128
42
        builder.and_where_eq("refresh_token", quote(value));
129
42
    }
130
51
    if let Some(value) = cond.client_id {
131
3
        builder.and_where_eq("client_id", quote(value));
132
48
    }
133
51
    if let Some(value) = cond.user_id {
134
7
        builder.and_where_eq("user_id", quote(value));
135
44
    }
136
51
    builder
137
51
}