1
use std::{error::Error as StdError, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeZone, Utc};
5
use futures::TryStreamExt;
6
use sql_builder::{SqlBuilder, quote};
7
use sqlx::SqlitePool;
8

            
9
use super::super::dldata_buffer::{
10
    Cursor, DlDataBuffer, DlDataBufferModel, ListOptions, ListQueryCond, QueryCond, SortKey,
11
};
12

            
13
/// Model instance.
14
pub struct Model {
15
    /// The associated database connection.
16
    conn: Arc<SqlitePool>,
17
}
18

            
19
/// Paging cursor for the SQLite `list` implementation.
///
/// Only progress is stored: `offset` counts the rows already handed out through this cursor.
/// The caller supplies the list options again on every call, and `list` adds this offset to
/// them.
pub struct DbCursor {
    /// Number of rows already consumed through this cursor.
    offset: u64,
}
25

            
26
/// SQLite schema.
27
#[derive(sqlx::FromRow)]
28
struct Schema {
29
    data_id: String,
30
    unit_id: String,
31
    unit_code: String,
32
    application_id: String,
33
    application_code: String,
34
    network_id: String,
35
    network_addr: String,
36
    device_id: String,
37
    /// i64 as time tick from Epoch in milliseconds.
38
    created_at: i64,
39
    /// i64 as time tick from Epoch in milliseconds.
40
    expired_at: i64,
41
}
42

            
43
/// Use "COUNT(*)" instead of "COUNT(fields...)" to simplify the implementation.
44
#[derive(sqlx::FromRow)]
45
struct CountSchema {
46
    #[sqlx(rename = "COUNT(*)")]
47
    count: i64,
48
}
49

            
50
/// Name of the SQLite table that stores buffered downlink data.
const TABLE_NAME: &str = "dldata_buffer";

/// Column list used for `SELECT`/`INSERT`. The order must match the values built in `add`.
const FIELDS: &[&str] = &[
    "data_id",
    "unit_id",
    "unit_code",
    "application_id",
    "application_code",
    "network_id",
    "network_addr",
    "device_id",
    "created_at",
    "expired_at",
];

/// DDL that creates the table when it does not exist yet.
///
/// Timestamps are `INTEGER` milliseconds since the Unix epoch.
/// NOTE(review): `UNIQUE` on `data_id` duplicates `PRIMARY KEY (data_id)`; it is kept so the
/// DDL text stays identical for existing deployments.
const TABLE_INIT_SQL: &str = "\
    CREATE TABLE IF NOT EXISTS dldata_buffer (\
    data_id TEXT NOT NULL UNIQUE,\
    unit_id TEXT NOT NULL,\
    unit_code TEXT NOT NULL,\
    application_id TEXT NOT NULL,\
    application_code TEXT NOT NULL,\
    network_id TEXT NOT NULL,\
    network_addr TEXT NOT NULL,\
    device_id TEXT NOT NULL,\
    created_at INTEGER NOT NULL,\
    expired_at INTEGER NOT NULL,\
    PRIMARY KEY (data_id))";
76

            
77
impl Model {
78
    /// To create the model instance with a database connection.
79
56
    pub async fn new(conn: Arc<SqlitePool>) -> Result<Self, Box<dyn StdError>> {
80
56
        let model = Model { conn };
81
56
        model.init().await?;
82
56
        Ok(model)
83
56
    }
84
}
85

            
86
#[async_trait]
87
impl DlDataBufferModel for Model {
88
96
    async fn init(&self) -> Result<(), Box<dyn StdError>> {
89
        let _ = sqlx::query(TABLE_INIT_SQL)
90
            .execute(self.conn.as_ref())
91
            .await?;
92
        Ok(())
93
96
    }
94

            
95
172
    async fn count(&self, cond: &ListQueryCond) -> Result<u64, Box<dyn StdError>> {
96
        let sql = build_list_where(SqlBuilder::select_from(TABLE_NAME).count("*"), &cond).sql()?;
97

            
98
        let result: Result<CountSchema, sqlx::Error> = sqlx::query_as(sql.as_str())
99
            .fetch_one(self.conn.as_ref())
100
            .await;
101

            
102
        let row = match result {
103
            Err(e) => return Err(Box::new(e)),
104
            Ok(row) => row,
105
        };
106
        Ok(row.count as u64)
107
172
    }
108

            
109
    async fn list(
110
        &self,
111
        opts: &ListOptions,
112
        cursor: Option<Box<dyn Cursor>>,
113
380
    ) -> Result<(Vec<DlDataBuffer>, Option<Box<dyn Cursor>>), Box<dyn StdError>> {
114
        let mut cursor = match cursor {
115
            None => Box::new(DbCursor::new()),
116
            Some(cursor) => cursor,
117
        };
118

            
119
        let mut opts = ListOptions { ..*opts };
120
        if let Some(offset) = opts.offset {
121
            opts.offset = Some(offset + cursor.offset());
122
        } else {
123
            opts.offset = Some(cursor.offset());
124
        }
125
        let opts_limit = opts.limit;
126
        if let Some(limit) = opts_limit {
127
            if limit > 0 {
128
                if cursor.offset() >= limit {
129
                    return Ok((vec![], None));
130
                }
131
                opts.limit = Some(limit - cursor.offset());
132
            }
133
        }
134
        let mut builder = SqlBuilder::select_from(TABLE_NAME);
135
        build_limit_offset(&mut builder, &opts);
136
        build_sort(&mut builder, &opts);
137
        let sql = build_list_where(&mut builder, opts.cond).sql()?;
138

            
139
        let mut rows = sqlx::query_as::<_, Schema>(sql.as_str()).fetch(self.conn.as_ref());
140

            
141
        let mut count: u64 = 0;
142
        let mut list = vec![];
143
        while let Some(row) = rows.try_next().await? {
144
            let _ = cursor.as_mut().try_next().await?;
145
            list.push(DlDataBuffer {
146
                data_id: row.data_id,
147
                unit_id: row.unit_id,
148
                unit_code: row.unit_code,
149
                application_id: row.application_id,
150
                application_code: row.application_code,
151
                network_id: row.network_id,
152
                network_addr: row.network_addr,
153
                device_id: row.device_id,
154
                created_at: Utc.timestamp_nanos(row.created_at * 1000000),
155
                expired_at: Utc.timestamp_nanos(row.expired_at * 1000000),
156
            });
157
            if let Some(limit) = opts_limit {
158
                if limit > 0 && cursor.offset() >= limit {
159
                    if let Some(cursor_max) = opts.cursor_max {
160
                        if (count + 1) >= cursor_max {
161
                            return Ok((list, Some(cursor)));
162
                        }
163
                    }
164
                    return Ok((list, None));
165
                }
166
            }
167
            if let Some(cursor_max) = opts.cursor_max {
168
                count += 1;
169
                if count >= cursor_max {
170
                    return Ok((list, Some(cursor)));
171
                }
172
            }
173
        }
174
        Ok((list, None))
175
380
    }
176

            
177
386
    async fn get(&self, data_id: &str) -> Result<Option<DlDataBuffer>, Box<dyn StdError>> {
178
        let sql = SqlBuilder::select_from(TABLE_NAME)
179
            .fields(FIELDS)
180
            .and_where_eq("data_id", quote(data_id))
181
            .sql()?;
182

            
183
        let result: Result<Schema, sqlx::Error> = sqlx::query_as(sql.as_str())
184
            .fetch_one(self.conn.as_ref())
185
            .await;
186

            
187
        let row = match result {
188
            Err(e) => match e {
189
                sqlx::Error::RowNotFound => return Ok(None),
190
                _ => return Err(Box::new(e)),
191
            },
192
            Ok(row) => row,
193
        };
194

            
195
        Ok(Some(DlDataBuffer {
196
            data_id: row.data_id,
197
            unit_id: row.unit_id,
198
            unit_code: row.unit_code,
199
            application_id: row.application_id,
200
            application_code: row.application_code,
201
            network_id: row.network_id,
202
            network_addr: row.network_addr,
203
            device_id: row.device_id,
204
            created_at: Utc.timestamp_nanos(row.created_at * 1000000),
205
            expired_at: Utc.timestamp_nanos(row.expired_at * 1000000),
206
        }))
207
386
    }
208

            
209
2284
    async fn add(&self, dldata: &DlDataBuffer) -> Result<(), Box<dyn StdError>> {
210
        let values = vec![
211
            quote(dldata.data_id.as_str()),
212
            quote(dldata.unit_id.as_str()),
213
            quote(dldata.unit_code.as_str()),
214
            quote(dldata.application_id.as_str()),
215
            quote(dldata.application_code.as_str()),
216
            quote(dldata.network_id.as_str()),
217
            quote(dldata.network_addr.as_str()),
218
            quote(dldata.device_id.as_str()),
219
            dldata.created_at.timestamp_millis().to_string(),
220
            dldata.expired_at.timestamp_millis().to_string(),
221
        ];
222
        let sql = SqlBuilder::insert_into(TABLE_NAME)
223
            .fields(FIELDS)
224
            .values(&values)
225
            .sql()?;
226
        let _ = sqlx::query(sql.as_str())
227
            .execute(self.conn.as_ref())
228
            .await?;
229
        Ok(())
230
2284
    }
231

            
232
176
    async fn del(&self, cond: &QueryCond) -> Result<(), Box<dyn StdError>> {
233
        let sql = build_where(&mut SqlBuilder::delete_from(TABLE_NAME), cond).sql()?;
234
        let _ = sqlx::query(sql.as_str())
235
            .execute(self.conn.as_ref())
236
            .await?;
237
        Ok(())
238
176
    }
239
}
240

            
241
impl DbCursor {
242
    /// To create the cursor instance.
243
312
    pub fn new() -> Self {
244
312
        DbCursor { offset: 0 }
245
312
    }
246
}
247

            
248
#[async_trait]
249
impl Cursor for DbCursor {
250
7284
    async fn try_next(&mut self) -> Result<Option<DlDataBuffer>, Box<dyn StdError>> {
251
        self.offset += 1;
252
        Ok(None)
253
7284
    }
254

            
255
5652
    fn offset(&self) -> u64 {
256
5652
        self.offset
257
5652
    }
258
}
259

            
260
/// Transforms query conditions to the SQL builder.
261
176
fn build_where<'a>(builder: &'a mut SqlBuilder, cond: &QueryCond<'a>) -> &'a mut SqlBuilder {
262
176
    if let Some(value) = cond.data_id {
263
28
        builder.and_where_eq("data_id", quote(value));
264
148
    }
265
176
    if let Some(value) = cond.unit_id {
266
60
        builder.and_where_eq("unit_id", quote(value));
267
116
    }
268
176
    if let Some(value) = cond.application_id {
269
22
        builder.and_where_eq("application_id", quote(value));
270
154
    }
271
176
    if let Some(value) = cond.network_id {
272
50
        builder.and_where_eq("network_id", quote(value));
273
126
    }
274
176
    if let Some(value) = cond.network_addrs {
275
16592
        let values: Vec<String> = value.iter().map(|&x| quote(x)).collect();
276
28
        builder.and_where_in("network_addr", &values);
277
148
    }
278
176
    if let Some(value) = cond.device_id {
279
42
        builder.and_where_eq("device_id", quote(value));
280
134
    }
281
176
    builder
282
176
}
283

            
284
/// Transforms query conditions to the SQL builder.
285
518
fn build_list_where<'a>(
286
518
    builder: &'a mut SqlBuilder,
287
518
    cond: &ListQueryCond<'a>,
288
518
) -> &'a mut SqlBuilder {
289
518
    if let Some(value) = cond.unit_id {
290
256
        builder.and_where_eq("unit_id", quote(value));
291
262
    }
292
518
    if let Some(value) = cond.application_id {
293
128
        builder.and_where_eq("application_id", quote(value));
294
390
    }
295
518
    if let Some(value) = cond.network_id {
296
128
        builder.and_where_eq("network_id", quote(value));
297
390
    }
298
518
    if let Some(value) = cond.device_id {
299
52
        builder.and_where_eq("device_id", quote(value));
300
466
    }
301
518
    builder
302
518
}
303

            
304
/// Transforms model options to the SQL builder.
305
346
fn build_limit_offset<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
306
346
    if let Some(value) = opts.limit {
307
282
        if value > 0 {
308
280
            builder.limit(value);
309
280
        }
310
64
    }
311
346
    if let Some(value) = opts.offset {
312
346
        match opts.limit {
313
64
            None => builder.limit(-1).offset(value),
314
2
            Some(0) => builder.limit(-1).offset(value),
315
280
            _ => builder.offset(value),
316
        };
317
    }
318
346
    builder
319
346
}
320

            
321
/// Transforms model options to the SQL builder.
322
346
fn build_sort<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
323
346
    if let Some(sort_cond) = opts.sort.as_ref() {
324
622
        for cond in sort_cond.iter() {
325
622
            let key = match cond.key {
326
278
                SortKey::CreatedAt => "created_at",
327
28
                SortKey::ExpiredAt => "expired_at",
328
316
                SortKey::ApplicationCode => "application_code",
329
            };
330
622
            builder.order_by(key, !cond.asc);
331
        }
332
10
    }
333
346
    builder
334
346
}