1
use std::{error::Error as StdError, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeZone, Utc};
5
use futures::TryStreamExt;
6
use sql_builder::{quote, SqlBuilder};
7
use sqlx::SqlitePool;
8

            
9
use super::{
10
    super::device::{
11
        Cursor, Device, DeviceModel, ListOptions, ListQueryCond, QueryCond, SortKey,
12
        UpdateQueryCond, Updates,
13
    },
14
    build_where_like,
15
};
16

            
17
/// Model instance.
18
pub struct Model {
19
    /// The associated database connection.
20
    conn: Arc<SqlitePool>,
21
}
22

            
23
/// Paging cursor returned by the SQLite `list` implementation.
///
/// The SQLite implementation does not keep a live database cursor. On every call it re-runs the
/// query with the caller's original list options, shifted by `offset`, so the cursor only needs
/// to remember how far it has progressed.
pub struct DbCursor {
    /// Number of rows already consumed through this cursor.
    offset: u64,
}
29

            
30
/// SQLite schema.
31
#[derive(sqlx::FromRow)]
32
struct Schema {
33
    device_id: String,
34
    unit_id: String,
35
    /// use empty string as NULL because duplicate
36
    /// `(unit_code=NULL,network_code="code",network_addr="addr")` is allowed.
37
    unit_code: String,
38
    network_id: String,
39
    network_code: String,
40
    network_addr: String,
41
    /// i64 as time tick from Epoch in milliseconds.
42
    created_at: i64,
43
    /// i64 as time tick from Epoch in milliseconds.
44
    modified_at: i64,
45
    profile: String,
46
    name: String,
47
    info: String,
48
}
49

            
50
/// Use "COUNT(*)" instead of "COUNT(fields...)" to simplify the implementation.
51
#[derive(sqlx::FromRow)]
52
struct CountSchema {
53
    #[sqlx(rename = "COUNT(*)")]
54
    count: i64,
55
}
56

            
57
/// Name of the SQLite table holding devices.
const TABLE_NAME: &str = "device";

/// Column list used for `SELECT` and `INSERT`, in the order the insert values are rendered.
const FIELDS: &[&str] = &[
    "device_id",
    "unit_id",
    "unit_code",
    "network_id",
    "network_code",
    "network_addr",
    "created_at",
    "modified_at",
    "profile",
    "name",
    "info",
];

/// DDL that creates the `device` table when it is missing.
const TABLE_INIT_SQL: &str = concat!(
    "CREATE TABLE IF NOT EXISTS device (",
    "device_id TEXT NOT NULL UNIQUE,",
    "unit_id TEXT NOT NULL,",
    "unit_code TEXT NOT NULL,",
    "network_id TEXT NOT NULL,",
    "network_code TEXT NOT NULL,",
    "network_addr TEXT NOT NULL,",
    "created_at INTEGER NOT NULL,",
    "modified_at INTEGER NOT NULL,",
    "profile TEXT NOT NULL,",
    "name TEXT NOT NULL,",
    "info TEXT,",
    "UNIQUE (unit_code,network_code,network_addr),",
    "PRIMARY KEY (device_id))"
);
86

            
87
impl Model {
88
    /// To create the model instance with a database connection.
89
56
    pub async fn new(conn: Arc<SqlitePool>) -> Result<Self, Box<dyn StdError>> {
90
56
        let model = Model { conn };
91
56
        model.init().await?;
92
56
        Ok(model)
93
56
    }
94
}
95

            
96
#[async_trait]
97
impl DeviceModel for Model {
98
96
    async fn init(&self) -> Result<(), Box<dyn StdError>> {
99
96
        let _ = sqlx::query(TABLE_INIT_SQL)
100
96
            .execute(self.conn.as_ref())
101
96
            .await?;
102
96
        Ok(())
103
192
    }
104

            
105
158
    async fn count(&self, cond: &ListQueryCond) -> Result<u64, Box<dyn StdError>> {
106
158
        let sql = build_list_where(SqlBuilder::select_from(TABLE_NAME).count("*"), &cond).sql()?;
107

            
108
158
        let result: Result<CountSchema, sqlx::Error> = sqlx::query_as(sql.as_str())
109
158
            .fetch_one(self.conn.as_ref())
110
158
            .await;
111

            
112
158
        let row = match result {
113
            Err(e) => return Err(Box::new(e)),
114
158
            Ok(row) => row,
115
158
        };
116
158
        Ok(row.count as u64)
117
316
    }
118

            
119
    async fn list(
120
        &self,
121
        opts: &ListOptions,
122
        cursor: Option<Box<dyn Cursor>>,
123
738
    ) -> Result<(Vec<Device>, Option<Box<dyn Cursor>>), Box<dyn StdError>> {
124
738
        let mut cursor = match cursor {
125
654
            None => Box::new(DbCursor::new()),
126
84
            Some(cursor) => cursor,
127
        };
128

            
129
738
        let mut opts = ListOptions { ..*opts };
130
738
        if let Some(offset) = opts.offset {
131
48
            opts.offset = Some(offset + cursor.offset());
132
690
        } else {
133
690
            opts.offset = Some(cursor.offset());
134
690
        }
135
738
        let opts_limit = opts.limit;
136
738
        if let Some(limit) = opts_limit {
137
284
            if limit > 0 {
138
282
                if cursor.offset() >= limit {
139
26
                    return Ok((vec![], None));
140
256
                }
141
256
                opts.limit = Some(limit - cursor.offset());
142
2
            }
143
454
        }
144
712
        let mut builder = SqlBuilder::select_from(TABLE_NAME);
145
712
        build_limit_offset(&mut builder, &opts);
146
712
        build_sort(&mut builder, &opts);
147
712
        let sql = build_list_where(&mut builder, opts.cond).sql()?;
148

            
149
712
        let mut rows = sqlx::query_as::<_, Schema>(sql.as_str()).fetch(self.conn.as_ref());
150
712

            
151
712
        let mut count: u64 = 0;
152
712
        let mut list = vec![];
153
149332
        while let Some(row) = rows.try_next().await? {
154
148724
            let _ = cursor.as_mut().try_next().await?;
155
148724
            list.push(Device {
156
148724
                device_id: row.device_id,
157
148724
                unit_id: row.unit_id,
158
148724
                unit_code: match row.unit_code.len() {
159
43520
                    0 => None,
160
105204
                    _ => Some(row.unit_code),
161
                },
162
148724
                network_id: row.network_id,
163
148724
                network_code: row.network_code,
164
148724
                network_addr: row.network_addr,
165
148724
                created_at: Utc.timestamp_nanos(row.created_at * 1000000),
166
148724
                modified_at: Utc.timestamp_nanos(row.modified_at * 1000000),
167
148724
                profile: row.profile,
168
148724
                name: row.name,
169
148724
                info: serde_json::from_str(row.info.as_str())?,
170
            });
171
148724
            if let Some(limit) = opts_limit {
172
4074
                if limit > 0 && cursor.offset() >= limit {
173
46
                    if let Some(cursor_max) = opts.cursor_max {
174
44
                        if (count + 1) >= cursor_max {
175
26
                            return Ok((list, Some(cursor)));
176
18
                        }
177
2
                    }
178
20
                    return Ok((list, None));
179
4028
                }
180
144650
            }
181
148678
            if let Some(cursor_max) = opts.cursor_max {
182
8842
                count += 1;
183
8842
                if count >= cursor_max {
184
58
                    return Ok((list, Some(cursor)));
185
8784
                }
186
139836
            }
187
        }
188
608
        Ok((list, None))
189
1476
    }
190

            
191
982
    async fn get(&self, cond: &QueryCond) -> Result<Option<Device>, Box<dyn StdError>> {
192
982
        let sql = build_where(SqlBuilder::select_from(TABLE_NAME).fields(FIELDS), &cond).sql()?;
193

            
194
982
        let result: Result<Schema, sqlx::Error> = sqlx::query_as(sql.as_str())
195
982
            .fetch_one(self.conn.as_ref())
196
982
            .await;
197

            
198
982
        let row = match result {
199
156
            Err(e) => match e {
200
156
                sqlx::Error::RowNotFound => return Ok(None),
201
                _ => return Err(Box::new(e)),
202
            },
203
826
            Ok(row) => row,
204
826
        };
205
826

            
206
826
        Ok(Some(Device {
207
826
            device_id: row.device_id,
208
826
            unit_id: row.unit_id,
209
826
            unit_code: match row.unit_code.len() {
210
388
                0 => None,
211
438
                _ => Some(row.unit_code),
212
            },
213
826
            network_id: row.network_id,
214
826
            network_code: row.network_code,
215
826
            network_addr: row.network_addr,
216
826
            created_at: Utc.timestamp_nanos(row.created_at * 1000000),
217
826
            modified_at: Utc.timestamp_nanos(row.modified_at * 1000000),
218
826
            profile: row.profile,
219
826
            name: row.name,
220
826
            info: serde_json::from_str(row.info.as_str())?,
221
        }))
222
1964
    }
223

            
224
4992
    async fn add(&self, device: &Device) -> Result<(), Box<dyn StdError>> {
225
4992
        let unit_code = match device.unit_code.as_deref() {
226
3644
            None => quote(""),
227
1348
            Some(value) => quote(value),
228
        };
229
4992
        let info = match serde_json::to_string(&device.info) {
230
            Err(_) => quote("{}"),
231
4992
            Ok(value) => quote(value.as_str()),
232
        };
233
4992
        let values = vec![
234
4992
            quote(device.device_id.as_str()),
235
4992
            quote(device.unit_id.as_str()),
236
4992
            unit_code,
237
4992
            quote(device.network_id.as_str()),
238
4992
            quote(device.network_code.as_str()),
239
4992
            quote(device.network_addr.as_str()),
240
4992
            device.created_at.timestamp_millis().to_string(),
241
4992
            device.modified_at.timestamp_millis().to_string(),
242
4992
            quote(device.profile.as_str()),
243
4992
            quote(device.name.as_str()),
244
4992
            info,
245
4992
        ];
246
4992
        let sql = SqlBuilder::insert_into(TABLE_NAME)
247
4992
            .fields(FIELDS)
248
4992
            .values(&values)
249
4992
            .sql()?;
250
4992
        let _ = sqlx::query(sql.as_str())
251
4992
            .execute(self.conn.as_ref())
252
4992
            .await?;
253
4988
        Ok(())
254
9984
    }
255

            
256
142
    async fn add_bulk(&self, devices: &Vec<Device>) -> Result<(), Box<dyn StdError>> {
257
142
        let mut builder = SqlBuilder::insert_into(TABLE_NAME);
258
142
        builder.fields(FIELDS);
259

            
260
82610
        for device in devices.iter() {
261
82610
            let unit_code = match device.unit_code.as_deref() {
262
25184
                None => quote(""),
263
57426
                Some(value) => quote(value),
264
            };
265
82610
            let info = match serde_json::to_string(&device.info) {
266
                Err(_) => quote("{}"),
267
82610
                Ok(value) => quote(value.as_str()),
268
            };
269

            
270
82610
            builder.values(&vec![
271
82610
                quote(device.device_id.as_str()),
272
82610
                quote(device.unit_id.as_str()),
273
82610
                unit_code,
274
82610
                quote(device.network_id.as_str()),
275
82610
                quote(device.network_code.as_str()),
276
82610
                quote(device.network_addr.as_str()),
277
82610
                device.created_at.timestamp_millis().to_string(),
278
82610
                device.modified_at.timestamp_millis().to_string(),
279
82610
                quote(device.profile.as_str()),
280
82610
                quote(device.name.as_str()),
281
82610
                info,
282
82610
            ]);
283
        }
284
142
        let sql = builder.sql()?.replace(");", ") ON CONFLICT DO NOTHING;");
285
142
        let _ = sqlx::query(sql.as_str())
286
142
            .execute(self.conn.as_ref())
287
142
            .await?;
288
142
        Ok(())
289
284
    }
290

            
291
212
    async fn del(&self, cond: &QueryCond) -> Result<(), Box<dyn StdError>> {
292
212
        let sql = build_where(&mut SqlBuilder::delete_from(TABLE_NAME), cond).sql()?;
293
212
        let _ = sqlx::query(sql.as_str())
294
212
            .execute(self.conn.as_ref())
295
212
            .await?;
296
212
        Ok(())
297
424
    }
298

            
299
    async fn update(
300
        &self,
301
        cond: &UpdateQueryCond,
302
        updates: &Updates,
303
66
    ) -> Result<(), Box<dyn StdError>> {
304
66
        let sql = match build_update_where(&mut SqlBuilder::update_table(TABLE_NAME), cond, updates)
305
        {
306
2
            None => return Ok(()),
307
64
            Some(builder) => builder.sql()?,
308
        };
309
64
        let _ = sqlx::query(sql.as_str())
310
64
            .execute(self.conn.as_ref())
311
64
            .await?;
312
64
        Ok(())
313
132
    }
314
}
315

            
316
impl DbCursor {
317
    /// To create the cursor instance.
318
654
    pub fn new() -> Self {
319
654
        DbCursor { offset: 0 }
320
654
    }
321
}
322

            
323
#[async_trait]
324
impl Cursor for DbCursor {
325
148724
    async fn try_next(&mut self) -> Result<Option<Device>, Box<dyn StdError>> {
326
148724
        self.offset += 1;
327
148724
        Ok(None)
328
297448
    }
329

            
330
5350
    fn offset(&self) -> u64 {
331
5350
        self.offset
332
5350
    }
333
}
334

            
335
/// Transforms query conditions to the SQL builder.
336
1194
fn build_where<'a>(builder: &'a mut SqlBuilder, cond: &QueryCond<'a>) -> &'a mut SqlBuilder {
337
1194
    if let Some(value) = cond.unit_id {
338
126
        builder.and_where_eq("unit_id", quote(value));
339
1068
    }
340
1194
    if let Some(value) = cond.device_id {
341
960
        builder.and_where_eq("device_id", quote(value));
342
960
    }
343
1194
    if let Some(value) = cond.network_id {
344
70
        builder.and_where_eq("network_id", quote(value));
345
1124
    }
346
1194
    if let Some(value) = cond.network_addrs {
347
32992
        let values: Vec<String> = value.iter().map(|&x| quote(x)).collect();
348
44
        builder.and_where_in("network_addr", &values);
349
1150
    }
350
1194
    if let Some(value) = cond.device.as_ref() {
351
118
        if let Some(unit_code) = value.unit_code {
352
88
            builder.and_where_eq("unit_code", quote(unit_code));
353
88
        } else {
354
30
            builder.and_where_eq("unit_code", quote(""));
355
30
        }
356
118
        builder.and_where_eq("network_code", quote(value.network_code));
357
118
        builder.and_where_eq("network_addr", quote(value.network_addr));
358
1076
    }
359
1194
    builder
360
1194
}
361

            
362
/// Transforms query conditions to the SQL builder.
363
870
fn build_list_where<'a>(
364
870
    builder: &'a mut SqlBuilder,
365
870
    cond: &ListQueryCond<'a>,
366
870
) -> &'a mut SqlBuilder {
367
870
    if let Some(value) = cond.unit_id {
368
344
        builder.and_where_eq("unit_id", quote(value));
369
526
    }
370
870
    if let Some(value) = cond.device_id {
371
20
        builder.and_where_eq("device_id", quote(value));
372
850
    }
373
870
    if let Some(value) = cond.network_id {
374
324
        builder.and_where_eq("network_id", quote(value));
375
546
    }
376
870
    if let Some(value) = cond.network_code {
377
92
        builder.and_where_eq("network_code", quote(value));
378
778
    }
379
870
    if let Some(value) = cond.network_addr {
380
204
        builder.and_where_eq("network_addr", quote(value));
381
666
    } else if let Some(value) = cond.network_addrs {
382
73780
        let values: Vec<String> = value.iter().map(|&x| quote(x)).collect();
383
108
        builder.and_where_in("network_addr", &values);
384
558
    }
385
870
    if let Some(value) = cond.profile {
386
28
        build_where_like(builder, "profile", value.to_lowercase().as_str());
387
842
    }
388
870
    if let Some(value) = cond.name_contains {
389
144
        build_where_like(builder, "name", value.to_lowercase().as_str());
390
726
    }
391
870
    builder
392
870
}
393

            
394
/// Transforms model options to the SQL builder.
395
712
fn build_limit_offset<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
396
712
    if let Some(value) = opts.limit {
397
258
        if value > 0 {
398
256
            builder.limit(value);
399
256
        }
400
454
    }
401
712
    if let Some(value) = opts.offset {
402
712
        match opts.limit {
403
454
            None => builder.limit(-1).offset(value),
404
2
            Some(0) => builder.limit(-1).offset(value),
405
256
            _ => builder.offset(value),
406
        };
407
    }
408
712
    builder
409
712
}
410

            
411
/// Transforms model options to the SQL builder.
412
712
fn build_sort<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
413
712
    if let Some(sort_cond) = opts.sort.as_ref() {
414
724
        for cond in sort_cond.iter() {
415
724
            let key = match cond.key {
416
16
                SortKey::CreatedAt => "created_at",
417
12
                SortKey::ModifiedAt => "modified_at",
418
280
                SortKey::NetworkCode => "network_code",
419
396
                SortKey::NetworkAddr => "network_addr",
420
4
                SortKey::Profile => "profile",
421
16
                SortKey::Name => "name",
422
            };
423
724
            builder.order_by(key, !cond.asc);
424
        }
425
270
    }
426
712
    builder
427
712
}
428

            
429
/// Transforms query conditions and the model object to the SQL builder.
430
66
fn build_update_where<'a>(
431
66
    builder: &'a mut SqlBuilder,
432
66
    cond: &UpdateQueryCond<'a>,
433
66
    updates: &Updates,
434
66
) -> Option<&'a mut SqlBuilder> {
435
66
    let mut count = 0;
436
66
    if let Some((network_id, network_code)) = updates.network {
437
20
        builder.set("network_id", quote(network_id));
438
20
        builder.set("network_code", quote(network_code));
439
20
        count += 1;
440
46
    }
441
66
    if let Some(value) = updates.network_addr {
442
28
        builder.set("network_addr", quote(value));
443
28
        count += 1;
444
38
    }
445
66
    if let Some(value) = updates.modified_at.as_ref() {
446
64
        builder.set("modified_at", value.timestamp_millis());
447
64
        count += 1;
448
64
    }
449
66
    if let Some(value) = updates.profile.as_ref() {
450
40
        builder.set("profile", quote(value));
451
40
        count += 1;
452
40
    }
453
66
    if let Some(value) = updates.name.as_ref() {
454
40
        builder.set("name", quote(value));
455
40
        count += 1;
456
40
    }
457
66
    if let Some(value) = updates.info {
458
36
        match serde_json::to_string(value) {
459
            Err(_) => {
460
                builder.set("info", quote("{}"));
461
            }
462
36
            Ok(value) => {
463
36
                builder.set("info", quote(value));
464
36
            }
465
        }
466
36
        count += 1;
467
30
    }
468
66
    if count == 0 {
469
2
        return None;
470
64
    }
471
64

            
472
64
    builder.and_where_eq("device_id", quote(cond.device_id));
473
64
    Some(builder)
474
66
}