1
use std::{error::Error as StdError, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeZone, Utc};
5
use sql_builder::{quote, SqlBuilder};
6
use sqlx::SqlitePool;
7

            
8
use super::super::login_session::{LoginSession, LoginSessionModel, QueryCond};
9

            
10
/// Model instance.
11
pub struct Model {
12
    /// The associated database connection.
13
    conn: Arc<SqlitePool>,
14
}
15

            
16
/// SQLite schema.
17
#[derive(sqlx::FromRow)]
18
struct Schema {
19
    session_id: String,
20
    /// i64 as time tick from Epoch in milliseconds.
21
    expires_at: i64,
22
    user_id: String,
23
}
24

            
25
/// Name of the SQLite table that stores login sessions.
const TABLE_NAME: &str = "login_session";
/// Column list, in the order used for both `SELECT` and `INSERT` statements.
const FIELDS: &[&str] = &["session_id", "expires_at", "user_id"];
/// DDL that creates the session table when it is missing.
///
/// The table name is spelled out literally here and must stay in sync with [`TABLE_NAME`].
/// `expires_at` holds milliseconds since the Unix epoch.
const TABLE_INIT_SQL: &str = "\
    CREATE TABLE IF NOT EXISTS login_session (\
    session_id TEXT NOT NULL UNIQUE,\
    expires_at INTEGER NOT NULL,\
    user_id TEXT NOT NULL,\
    PRIMARY KEY (session_id))";
33

            
34
impl Model {
35
    /// To create the model instance with a database connection.
36
11
    pub async fn new(conn: Arc<SqlitePool>) -> Result<Self, Box<dyn StdError>> {
37
11
        let model = Model { conn };
38
22
        model.init().await?;
39
11
        Ok(model)
40
11
    }
41
}
42

            
43
#[async_trait]
44
impl LoginSessionModel for Model {
45
20
    async fn init(&self) -> Result<(), Box<dyn StdError>> {
46
20
        let _ = sqlx::query(TABLE_INIT_SQL)
47
20
            .execute(self.conn.as_ref())
48
40
            .await?;
49
20
        Ok(())
50
20
    }
51

            
52
149
    async fn get(&self, session_id: &str) -> Result<Option<LoginSession>, Box<dyn StdError>> {
53
149
        let cond = QueryCond {
54
149
            session_id: Some(session_id),
55
149
            ..Default::default()
56
149
        };
57
149
        let sql = get_query_sql(SqlBuilder::select_from(TABLE_NAME).fields(FIELDS), &cond).sql()?;
58
149

            
59
149
        let result: Result<Schema, sqlx::Error> = sqlx::query_as(sql.as_str())
60
149
            .fetch_one(self.conn.as_ref())
61
298
            .await;
62
149

            
63
149
        let row = match result {
64
149
            Err(e) => match e {
65
149
                sqlx::Error::RowNotFound => return Ok(None),
66
149
                _ => return Err(Box::new(e)),
67
149
            },
68
149
            Ok(row) => row,
69
144
        };
70
144
        Ok(Some(LoginSession {
71
144
            session_id: row.session_id,
72
144
            expires_at: Utc.timestamp_nanos(row.expires_at * 1000000),
73
144
            user_id: row.user_id,
74
144
        }))
75
149
    }
76

            
77
153
    async fn add(&self, session: &LoginSession) -> Result<(), Box<dyn StdError>> {
78
153
        let values = vec![
79
153
            quote(session.session_id.as_str()),
80
153
            session.expires_at.timestamp_millis().to_string(),
81
153
            quote(session.user_id.as_str()),
82
153
        ];
83
153
        let sql = SqlBuilder::insert_into(TABLE_NAME)
84
153
            .fields(FIELDS)
85
153
            .values(&values)
86
153
            .sql()?;
87
153
        let _ = sqlx::query(sql.as_str())
88
153
            .execute(self.conn.as_ref())
89
306
            .await?;
90
153
        Ok(())
91
153
    }
92

            
93
142
    async fn del(&self, cond: &QueryCond) -> Result<(), Box<dyn StdError>> {
94
142
        let sql = get_query_sql(&mut SqlBuilder::delete_from(TABLE_NAME), cond).sql()?;
95
142
        let _ = sqlx::query(sql.as_str())
96
142
            .execute(self.conn.as_ref())
97
284
            .await?;
98
142
        Ok(())
99
142
    }
100
}
101

            
102
/// Transforms query conditions to the SQL builder.
103
291
fn get_query_sql<'a>(builder: &'a mut SqlBuilder, cond: &QueryCond<'a>) -> &'a mut SqlBuilder {
104
291
    if let Some(value) = cond.session_id {
105
290
        builder.and_where_eq("session_id", quote(value));
106
290
    }
107
291
    if let Some(value) = cond.user_id {
108
1
        builder.and_where_eq("user_id", quote(value));
109
290
    }
110
291
    builder
111
291
}