1
use std::{error::Error as StdError, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeZone, Utc};
5
use sql_builder::{quote, SqlBuilder};
6
use sqlx::SqlitePool;
7

            
8
use super::super::access_token::{AccessToken, AccessTokenModel, QueryCond};
9

            
10
/// Model instance.
11
pub struct Model {
12
    /// The associated database connection.
13
    conn: Arc<SqlitePool>,
14
}
15

            
16
/// SQLite schema.
17
#[derive(sqlx::FromRow)]
18
struct Schema {
19
    access_token: String,
20
    refresh_token: Option<String>,
21
    /// i64 as time tick from Epoch in milliseconds.
22
    expires_at: i64,
23
    scope: Option<String>,
24
    client_id: String,
25
    redirect_uri: String,
26
    user_id: String,
27
}
28

            
29
/// Name of the SQLite table backing this model.
const TABLE_NAME: &str = "access_token";

/// Column list used for both `SELECT` and `INSERT` statements.
///
/// The order matters: `add` builds its `values` vector in exactly this order.
const FIELDS: &[&str] = &[
    "access_token",
    "refresh_token",
    "expires_at",
    "scope",
    "client_id",
    "redirect_uri",
    "user_id",
];

/// DDL executed by `init`; `IF NOT EXISTS` makes repeated initialization harmless.
///
/// NOTE(review): `UNIQUE` on `access_token` is redundant with
/// `PRIMARY KEY (access_token)`; it is kept so newly created tables keep the
/// same schema as existing ones.
const TABLE_INIT_SQL: &str = "\
    CREATE TABLE IF NOT EXISTS access_token (\
    access_token TEXT NOT NULL UNIQUE,\
    refresh_token TEXT,\
    expires_at INTEGER NOT NULL,\
    scope TEXT,\
    client_id TEXT NOT NULL,\
    redirect_uri TEXT NOT NULL,\
    user_id TEXT NOT NULL,\
    PRIMARY KEY (access_token))";
49

            
50
impl Model {
51
    /// To create the model instance with a database connection.
52
22
    pub async fn new(conn: Arc<SqlitePool>) -> Result<Self, Box<dyn StdError>> {
53
22
        let model = Model { conn };
54
22
        model.init().await?;
55
22
        Ok(model)
56
22
    }
57
}
58

            
59
#[async_trait]
60
impl AccessTokenModel for Model {
61
40
    async fn init(&self) -> Result<(), Box<dyn StdError>> {
62
40
        let _ = sqlx::query(TABLE_INIT_SQL)
63
40
            .execute(self.conn.as_ref())
64
40
            .await?;
65
40
        Ok(())
66
80
    }
67

            
68
538
    async fn get(&self, access_token: &str) -> Result<Option<AccessToken>, Box<dyn StdError>> {
69
538
        let cond = QueryCond {
70
538
            access_token: Some(access_token),
71
538
            ..Default::default()
72
538
        };
73
538
        let sql = get_query_sql(SqlBuilder::select_from(TABLE_NAME).fields(FIELDS), &cond).sql()?;
74

            
75
538
        let result: Result<Schema, sqlx::Error> = sqlx::query_as(sql.as_str())
76
538
            .fetch_one(self.conn.as_ref())
77
538
            .await;
78

            
79
538
        let row = match result {
80
60
            Err(e) => match e {
81
60
                sqlx::Error::RowNotFound => return Ok(None),
82
                _ => return Err(Box::new(e)),
83
            },
84
478
            Ok(row) => row,
85
478
        };
86
478
        Ok(Some(AccessToken {
87
478
            access_token: row.access_token,
88
478
            refresh_token: row.refresh_token,
89
478
            expires_at: Utc.timestamp_nanos(row.expires_at * 1000000),
90
478
            scope: row.scope,
91
478
            client_id: row.client_id,
92
478
            redirect_uri: row.redirect_uri,
93
478
            user_id: row.user_id,
94
478
        }))
95
1076
    }
96

            
97
302
    async fn add(&self, token: &AccessToken) -> Result<(), Box<dyn StdError>> {
98
302
        let refresh_token = match token.refresh_token.as_deref() {
99
20
            None => "NULL".to_string(),
100
282
            Some(token) => quote(token),
101
        };
102
302
        let scope = match token.scope.as_deref() {
103
8
            None => "NULL".to_string(),
104
294
            Some(scope) => quote(scope),
105
        };
106
302
        let values = vec![
107
302
            quote(token.access_token.as_str()),
108
302
            refresh_token,
109
302
            token.expires_at.timestamp_millis().to_string(),
110
302
            scope,
111
302
            quote(token.client_id.as_str()),
112
302
            quote(token.redirect_uri.as_str()),
113
302
            quote(token.user_id.as_str()),
114
302
        ];
115
302
        let sql = SqlBuilder::insert_into(TABLE_NAME)
116
302
            .fields(FIELDS)
117
302
            .values(&values)
118
302
            .sql()?;
119
302
        let _ = sqlx::query(sql.as_str())
120
302
            .execute(self.conn.as_ref())
121
302
            .await?;
122
300
        Ok(())
123
604
    }
124

            
125
40
    async fn del(&self, cond: &QueryCond) -> Result<(), Box<dyn StdError>> {
126
40
        let sql = get_query_sql(&mut SqlBuilder::delete_from(TABLE_NAME), cond).sql()?;
127
40
        let _ = sqlx::query(sql.as_str())
128
40
            .execute(self.conn.as_ref())
129
40
            .await?;
130
40
        Ok(())
131
80
    }
132
}
133

            
134
/// Transforms query conditions to the SQL builder.
135
578
fn get_query_sql<'a>(builder: &'a mut SqlBuilder, cond: &QueryCond<'a>) -> &'a mut SqlBuilder {
136
578
    if let Some(value) = cond.access_token {
137
546
        builder.and_where_eq("access_token", quote(value));
138
546
    }
139
578
    if let Some(value) = cond.refresh_token {
140
14
        builder.and_where_eq("refresh_token", quote(value));
141
564
    }
142
578
    if let Some(value) = cond.client_id {
143
6
        builder.and_where_eq("client_id", quote(value));
144
572
    }
145
578
    if let Some(value) = cond.user_id {
146
14
        builder.and_where_eq("user_id", quote(value));
147
564
    }
148
578
    builder
149
578
}