1
use std::{error::Error as StdError, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeZone, Utc};
5
use sql_builder::{quote, SqlBuilder};
6
use sqlx::SqlitePool;
7

            
8
use super::super::access_token::{AccessToken, AccessTokenModel, QueryCond};
9

            
10
/// Model instance.
11
pub struct Model {
12
    /// The associated database connection.
13
    conn: Arc<SqlitePool>,
14
}
15

            
16
/// SQLite schema.
17
#[derive(sqlx::FromRow)]
18
struct Schema {
19
    access_token: String,
20
    refresh_token: Option<String>,
21
    /// i64 as time tick from Epoch in milliseconds.
22
    expires_at: i64,
23
    scope: Option<String>,
24
    client_id: String,
25
    redirect_uri: String,
26
    user_id: String,
27
}
28

            
29
/// Name of the SQLite table that stores access tokens.
const TABLE_NAME: &str = "access_token";
/// Column names used for both `SELECT` and `INSERT` statements.
///
/// The order matters: `add` builds its value list in exactly this order.
const FIELDS: &[&str] = &[
    "access_token",
    "refresh_token",
    "expires_at",
    "scope",
    "client_id",
    "redirect_uri",
    "user_id",
];
/// DDL for the token table, executed idempotently at start-up.
///
/// `expires_at` holds milliseconds since the Unix epoch. The trailing `\` on each line is a
/// string continuation, so the resulting SQL is a single line without newlines.
const TABLE_INIT_SQL: &str = "\
    CREATE TABLE IF NOT EXISTS access_token (\
    access_token TEXT NOT NULL UNIQUE,\
    refresh_token TEXT,\
    expires_at INTEGER NOT NULL,\
    scope TEXT,\
    client_id TEXT NOT NULL,\
    redirect_uri TEXT NOT NULL,\
    user_id TEXT NOT NULL,\
    PRIMARY KEY (access_token))";
49

            
50
impl Model {
51
    /// To create the model instance with a database connection.
52
11
    pub async fn new(conn: Arc<SqlitePool>) -> Result<Self, Box<dyn StdError>> {
53
11
        let model = Model { conn };
54
21
        model.init().await?;
55
11
        Ok(model)
56
11
    }
57
}
58

            
59
#[async_trait]
60
impl AccessTokenModel for Model {
61
20
    async fn init(&self) -> Result<(), Box<dyn StdError>> {
62
20
        let _ = sqlx::query(TABLE_INIT_SQL)
63
20
            .execute(self.conn.as_ref())
64
39
            .await?;
65
20
        Ok(())
66
40
    }
67

            
68
269
    async fn get(&self, access_token: &str) -> Result<Option<AccessToken>, Box<dyn StdError>> {
69
269
        let cond = QueryCond {
70
269
            access_token: Some(access_token),
71
269
            ..Default::default()
72
269
        };
73
269
        let sql = get_query_sql(SqlBuilder::select_from(TABLE_NAME).fields(FIELDS), &cond).sql()?;
74

            
75
269
        let result: Result<Schema, sqlx::Error> = sqlx::query_as(sql.as_str())
76
269
            .fetch_one(self.conn.as_ref())
77
538
            .await;
78

            
79
269
        let row = match result {
80
30
            Err(e) => match e {
81
30
                sqlx::Error::RowNotFound => return Ok(None),
82
                _ => return Err(Box::new(e)),
83
            },
84
239
            Ok(row) => row,
85
239
        };
86
239
        Ok(Some(AccessToken {
87
239
            access_token: row.access_token,
88
239
            refresh_token: row.refresh_token,
89
239
            expires_at: Utc.timestamp_nanos(row.expires_at * 1000000),
90
239
            scope: row.scope,
91
239
            client_id: row.client_id,
92
239
            redirect_uri: row.redirect_uri,
93
239
            user_id: row.user_id,
94
239
        }))
95
538
    }
96

            
97
151
    async fn add(&self, token: &AccessToken) -> Result<(), Box<dyn StdError>> {
98
151
        let refresh_token = match token.refresh_token.as_deref() {
99
10
            None => "NULL".to_string(),
100
141
            Some(token) => quote(token),
101
        };
102
151
        let scope = match token.scope.as_deref() {
103
4
            None => "NULL".to_string(),
104
147
            Some(scope) => quote(scope),
105
        };
106
151
        let values = vec![
107
151
            quote(token.access_token.as_str()),
108
151
            refresh_token,
109
151
            token.expires_at.timestamp_millis().to_string(),
110
151
            scope,
111
151
            quote(token.client_id.as_str()),
112
151
            quote(token.redirect_uri.as_str()),
113
151
            quote(token.user_id.as_str()),
114
151
        ];
115
151
        let sql = SqlBuilder::insert_into(TABLE_NAME)
116
151
            .fields(FIELDS)
117
151
            .values(&values)
118
151
            .sql()?;
119
151
        let _ = sqlx::query(sql.as_str())
120
151
            .execute(self.conn.as_ref())
121
302
            .await?;
122
150
        Ok(())
123
302
    }
124

            
125
20
    async fn del(&self, cond: &QueryCond) -> Result<(), Box<dyn StdError>> {
126
20
        let sql = get_query_sql(&mut SqlBuilder::delete_from(TABLE_NAME), cond).sql()?;
127
20
        let _ = sqlx::query(sql.as_str())
128
20
            .execute(self.conn.as_ref())
129
40
            .await?;
130
20
        Ok(())
131
40
    }
132
}
133

            
134
/// Transforms query conditions to the SQL builder.
135
289
fn get_query_sql<'a>(builder: &'a mut SqlBuilder, cond: &QueryCond<'a>) -> &'a mut SqlBuilder {
136
289
    if let Some(value) = cond.access_token {
137
273
        builder.and_where_eq("access_token", quote(value));
138
273
    }
139
289
    if let Some(value) = cond.refresh_token {
140
7
        builder.and_where_eq("refresh_token", quote(value));
141
282
    }
142
289
    if let Some(value) = cond.client_id {
143
3
        builder.and_where_eq("client_id", quote(value));
144
286
    }
145
289
    if let Some(value) = cond.user_id {
146
7
        builder.and_where_eq("user_id", quote(value));
147
282
    }
148
289
    builder
149
289
}