1
use std::{error::Error as StdError, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeZone, Utc};
5
use futures::TryStreamExt;
6
use sql_builder::{quote, SqlBuilder};
7
use sqlx::SqlitePool;
8

            
9
use super::super::network_dldata::{
10
    Cursor, ListOptions, ListQueryCond, NetworkDlData, NetworkDlDataModel, QueryCond, SortKey,
11
    UpdateQueryCond, Updates,
12
};
13

            
14
/// Model instance.
15
pub struct Model {
16
    /// The associated database connection.
17
    conn: Arc<SqlitePool>,
18
}
19

            
20
/// Cursor instance.
///
/// The SQLite implementation stores only how many rows have already been returned
/// (`offset`). `list()` adds this value to the list options' offset and subtracts it
/// from their limit. The same `ListOptions` are therefore expected to be passed again,
/// together with this cursor, when fetching the next batch.
pub struct DbCursor {
    /// Number of rows already delivered through this cursor.
    offset: u64,
}
26

            
27
/// SQLite schema.
28
#[derive(sqlx::FromRow)]
29
struct Schema {
30
    pub data_id: String,
31
    /// i64 as time tick from Epoch in milliseconds.
32
    pub proc: i64,
33
    /// i64 as time tick from Epoch in milliseconds.
34
    #[sqlx(rename = "pub")]
35
    pub publish: i64,
36
    /// i64 as time tick from Epoch in milliseconds.
37
    pub resp: Option<i64>,
38
    pub status: i32,
39
    pub unit_id: String,
40
    /// use empty string as NULL.
41
    pub device_id: String,
42
    /// use empty string as NULL.
43
    pub network_code: String,
44
    /// use empty string as NULL.
45
    pub network_addr: String,
46
    pub profile: String,
47
    pub data: String,
48
    /// use empty string as NULL.
49
    pub extension: String,
50
}
51

            
52
/// Use "COUNT(*)" instead of "COUNT(fields...)" to simplify the implementation.
53
#[derive(sqlx::FromRow)]
54
struct CountSchema {
55
    #[sqlx(rename = "COUNT(*)")]
56
    count: i64,
57
}
58

            
59
/// Name of the SQLite table backing this model.
const TABLE_NAME: &str = "network_dldata";
/// Column names used by `INSERT` statements.
///
/// The order must match the order of the values built in `add()`, which passes this
/// list to `SqlBuilder::fields` and the values to `SqlBuilder::values`.
const FIELDS: &[&str] = &[
    "data_id",
    "proc",
    "pub",
    "resp",
    "status",
    "unit_id",
    "device_id",
    "network_code",
    "network_addr",
    "profile",
    "data",
    "extension",
];
/// DDL executed by `init()` to create the table on first use.
///
/// NOTE(review): `data_id` is declared both `UNIQUE` and `PRIMARY KEY`. The `UNIQUE`
/// is redundant as a constraint. It is kept unchanged so that newly created databases
/// get the same schema as existing ones.
const TABLE_INIT_SQL: &str = "\
    CREATE TABLE IF NOT EXISTS network_dldata (\
    data_id TEXT NOT NULL UNIQUE,\
    proc INTEGER NOT NULL,\
    pub INTEGER NOT NULL,\
    resp INTEGER,\
    status INTEGER NOT NULL,\
    unit_id TEXT NOT NULL,\
    device_id TEXT NOT NULL,\
    network_code TEXT NOT NULL,\
    network_addr TEXT NOT NULL,\
    profile TEXT NOT NULL,\
    data TEXT NOT NULL,\
    extension TEXT NOT NULL,\
    PRIMARY KEY (data_id))";
89

            
90
impl Model {
91
    /// To create the model instance with a database connection.
92
24
    pub async fn new(conn: Arc<SqlitePool>) -> Result<Self, Box<dyn StdError>> {
93
24
        let model = Model { conn };
94
24
        model.init().await?;
95
24
        Ok(model)
96
24
    }
97
}
98

            
99
#[async_trait]
100
impl NetworkDlDataModel for Model {
101
38
    async fn init(&self) -> Result<(), Box<dyn StdError>> {
102
38
        let _ = sqlx::query(TABLE_INIT_SQL)
103
38
            .execute(self.conn.as_ref())
104
38
            .await?;
105
38
        Ok(())
106
76
    }
107

            
108
88
    async fn count(&self, cond: &ListQueryCond) -> Result<u64, Box<dyn StdError>> {
109
88
        let sql = build_list_where(SqlBuilder::select_from(TABLE_NAME).count("*"), &cond).sql()?;
110

            
111
88
        let result: Result<CountSchema, sqlx::Error> = sqlx::query_as(sql.as_str())
112
88
            .fetch_one(self.conn.as_ref())
113
88
            .await;
114

            
115
88
        let row = match result {
116
            Err(e) => return Err(Box::new(e)),
117
88
            Ok(row) => row,
118
88
        };
119
88
        Ok(row.count as u64)
120
176
    }
121

            
122
    async fn list(
123
        &self,
124
        opts: &ListOptions,
125
        cursor: Option<Box<dyn Cursor>>,
126
206
    ) -> Result<(Vec<NetworkDlData>, Option<Box<dyn Cursor>>), Box<dyn StdError>> {
127
206
        let mut cursor = match cursor {
128
180
            None => Box::new(DbCursor::new()),
129
26
            Some(cursor) => cursor,
130
        };
131

            
132
206
        let mut opts = ListOptions { ..*opts };
133
206
        if let Some(offset) = opts.offset {
134
38
            opts.offset = Some(offset + cursor.offset());
135
168
        } else {
136
168
            opts.offset = Some(cursor.offset());
137
168
        }
138
206
        let opts_limit = opts.limit;
139
206
        if let Some(limit) = opts_limit {
140
154
            if limit > 0 {
141
142
                if cursor.offset() >= limit {
142
6
                    return Ok((vec![], None));
143
136
                }
144
136
                opts.limit = Some(limit - cursor.offset());
145
12
            }
146
52
        }
147
200
        let mut builder = SqlBuilder::select_from(TABLE_NAME);
148
200
        build_limit_offset(&mut builder, &opts);
149
200
        build_sort(&mut builder, &opts);
150
200
        let sql = build_list_where(&mut builder, opts.cond).sql()?;
151

            
152
200
        let mut rows = sqlx::query_as::<_, Schema>(sql.as_str()).fetch(self.conn.as_ref());
153
200

            
154
200
        let mut count: u64 = 0;
155
200
        let mut list = vec![];
156
2526
        while let Some(row) = rows.try_next().await? {
157
2370
            let _ = cursor.as_mut().try_next().await?;
158
2370
            list.push(NetworkDlData {
159
2370
                data_id: row.data_id,
160
2370
                proc: Utc.timestamp_nanos(row.proc * 1000000),
161
2370
                publish: Utc.timestamp_nanos(row.publish * 1000000),
162
2370
                resp: match row.resp {
163
1520
                    None => None,
164
850
                    Some(resp) => Some(Utc.timestamp_nanos(resp * 1000000)),
165
                },
166
2370
                status: row.status,
167
2370
                unit_id: row.unit_id,
168
2370
                device_id: row.device_id,
169
2370
                network_code: row.network_code,
170
2370
                network_addr: row.network_addr,
171
2370
                profile: row.profile,
172
2370
                data: row.data,
173
2370
                extension: match row.extension.len() {
174
1856
                    0 => None,
175
514
                    _ => serde_json::from_str(row.extension.as_str())?,
176
                },
177
            });
178
2370
            if let Some(limit) = opts_limit {
179
2198
                if limit > 0 && cursor.offset() >= limit {
180
24
                    if let Some(cursor_max) = opts.cursor_max {
181
22
                        if (count + 1) >= cursor_max {
182
6
                            return Ok((list, Some(cursor)));
183
16
                        }
184
2
                    }
185
18
                    return Ok((list, None));
186
2174
                }
187
172
            }
188
2346
            if let Some(cursor_max) = opts.cursor_max {
189
2156
                count += 1;
190
2156
                if count >= cursor_max {
191
20
                    return Ok((list, Some(cursor)));
192
2136
                }
193
190
            }
194
        }
195
156
        Ok((list, None))
196
412
    }
197

            
198
946
    async fn add(&self, data: &NetworkDlData) -> Result<(), Box<dyn StdError>> {
199
946
        let extension = match data.extension.as_ref() {
200
524
            None => quote(""),
201
422
            Some(extension) => match serde_json::to_string(extension) {
202
                Err(_) => quote("{}"),
203
422
                Ok(value) => quote(value.as_str()),
204
            },
205
        };
206
946
        let values = vec![
207
946
            quote(data.data_id.as_str()),
208
946
            data.proc.timestamp_millis().to_string(),
209
946
            data.publish.timestamp_millis().to_string(),
210
946
            match data.resp {
211
476
                None => "NULL".to_string(),
212
470
                Some(resp) => resp.timestamp_millis().to_string(),
213
            },
214
946
            data.status.to_string(),
215
946
            quote(data.unit_id.as_str()),
216
946
            quote(data.device_id.as_str()),
217
946
            quote(data.network_code.as_str()),
218
946
            quote(data.network_addr.as_str()),
219
946
            quote(data.profile.as_str()),
220
946
            quote(data.data.as_str()),
221
946
            extension,
222
        ];
223
946
        let sql = SqlBuilder::insert_into(TABLE_NAME)
224
946
            .fields(FIELDS)
225
946
            .values(&values)
226
946
            .sql()?;
227
946
        let _ = sqlx::query(sql.as_str())
228
946
            .execute(self.conn.as_ref())
229
946
            .await?;
230
944
        Ok(())
231
1892
    }
232

            
233
82
    async fn del(&self, cond: &QueryCond) -> Result<(), Box<dyn StdError>> {
234
82
        let sql = build_where(&mut SqlBuilder::delete_from(TABLE_NAME), cond).sql()?;
235
82
        let _ = sqlx::query(sql.as_str())
236
82
            .execute(self.conn.as_ref())
237
82
            .await?;
238
82
        Ok(())
239
164
    }
240

            
241
    async fn update(
242
        &self,
243
        cond: &UpdateQueryCond,
244
        updates: &Updates,
245
34
    ) -> Result<(), Box<dyn StdError>> {
246
34
        let sql = match build_update_where(&mut SqlBuilder::update_table(TABLE_NAME), cond, updates)
247
        {
248
            None => return Ok(()),
249
34
            Some(builder) => builder.sql()?,
250
        };
251
34
        let _ = sqlx::query(sql.as_str())
252
34
            .execute(self.conn.as_ref())
253
34
            .await?;
254
34
        Ok(())
255
68
    }
256
}
257

            
258
impl DbCursor {
259
    /// To create the cursor instance.
260
180
    pub fn new() -> Self {
261
180
        DbCursor { offset: 0 }
262
180
    }
263
}
264

            
265
#[async_trait]
266
impl Cursor for DbCursor {
267
2370
    async fn try_next(&mut self) -> Result<Option<NetworkDlData>, Box<dyn StdError>> {
268
2370
        self.offset += 1;
269
2370
        Ok(None)
270
4740
    }
271

            
272
1884
    fn offset(&self) -> u64 {
273
1884
        self.offset
274
1884
    }
275
}
276

            
277
/// Transforms query conditions to the SQL builder.
278
82
fn build_where<'a>(builder: &'a mut SqlBuilder, cond: &QueryCond<'a>) -> &'a mut SqlBuilder {
279
82
    if let Some(value) = cond.unit_id {
280
6
        builder.and_where_eq("unit_id", quote(value));
281
76
    }
282
82
    if let Some(value) = cond.device_id {
283
2
        builder.and_where_eq("device_id", quote(value));
284
80
    }
285
82
    if let Some(value) = cond.proc_gte {
286
2
        builder.and_where_ge("proc", value.timestamp_millis());
287
80
    }
288
82
    if let Some(value) = cond.proc_lte {
289
2
        builder.and_where_le("proc", value.timestamp_millis());
290
80
    }
291
82
    builder
292
82
}
293

            
294
/// Transforms query conditions to the SQL builder.
295
288
fn build_list_where<'a>(
296
288
    builder: &'a mut SqlBuilder,
297
288
    cond: &ListQueryCond<'a>,
298
288
) -> &'a mut SqlBuilder {
299
288
    if let Some(value) = cond.unit_id {
300
116
        builder.and_where_eq("unit_id", quote(value));
301
172
    }
302
288
    if let Some(value) = cond.device_id {
303
28
        builder.and_where_eq("device_id", quote(value));
304
260
    }
305
288
    if let Some(value) = cond.network_code {
306
16
        builder.and_where_eq("network_code", quote(value));
307
272
    }
308
288
    if let Some(value) = cond.network_addr {
309
12
        builder.and_where_eq("network_addr", quote(value));
310
276
    }
311
288
    if let Some(value) = cond.profile {
312
12
        builder.and_where_eq("profile", quote(value));
313
276
    }
314
288
    if let Some(value) = cond.proc_gte {
315
62
        builder.and_where_ge("proc", value.timestamp_millis());
316
226
    }
317
288
    if let Some(value) = cond.proc_lte {
318
16
        builder.and_where_le("proc", value.timestamp_millis());
319
272
    }
320
288
    if let Some(value) = cond.pub_gte {
321
28
        builder.and_where_ge("pub", value.timestamp_millis());
322
260
    }
323
288
    if let Some(value) = cond.pub_lte {
324
16
        builder.and_where_le("pub", value.timestamp_millis());
325
272
    }
326
288
    if let Some(value) = cond.resp_gte {
327
28
        builder.and_where_ge("resp", value.timestamp_millis());
328
260
    }
329
288
    if let Some(value) = cond.resp_lte {
330
16
        builder.and_where_le("resp", value.timestamp_millis());
331
272
    }
332
288
    builder
333
288
}
334

            
335
/// Transforms model options to the SQL builder.
336
200
fn build_limit_offset<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
337
200
    if let Some(value) = opts.limit {
338
148
        if value > 0 {
339
136
            builder.limit(value);
340
136
        }
341
52
    }
342
200
    if let Some(value) = opts.offset {
343
200
        match opts.limit {
344
52
            None => builder.limit(-1).offset(value),
345
12
            Some(0) => builder.limit(-1).offset(value),
346
136
            _ => builder.offset(value),
347
        };
348
    }
349
200
    builder
350
200
}
351

            
352
/// Transforms model options to the SQL builder.
353
200
fn build_sort<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
354
200
    if let Some(sort_cond) = opts.sort.as_ref() {
355
206
        for cond in sort_cond.iter() {
356
206
            let key = match cond.key {
357
158
                SortKey::Proc => "proc",
358
8
                SortKey::Pub => "pub",
359
8
                SortKey::Resp => "resp",
360
16
                SortKey::NetworkCode => "network_code",
361
16
                SortKey::NetworkAddr => "network_addr",
362
            };
363
206
            builder.order_by(key, !cond.asc);
364
        }
365
26
    }
366
200
    builder
367
200
}
368

            
369
/// Transforms query conditions and the model object to the SQL builder.
370
34
fn build_update_where<'a>(
371
34
    builder: &'a mut SqlBuilder,
372
34
    cond: &UpdateQueryCond<'a>,
373
34
    updates: &Updates,
374
34
) -> Option<&'a mut SqlBuilder> {
375
34
    builder.set("resp", updates.resp.timestamp_millis());
376
34
    builder.set("status", updates.status);
377
34
    builder.and_where_eq("data_id", quote(cond.data_id));
378
34
    if updates.status >= 0 {
379
26
        builder.and_where_ne("status", 0);
380
26
    } else {
381
8
        builder.and_where_lt("status", updates.status);
382
8
    }
383
34
    Some(builder)
384
34
}