1
use std::{error::Error as StdError, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeZone, Utc};
5
use futures::TryStreamExt;
6
use sql_builder::{quote, SqlBuilder};
7
use sqlx::SqlitePool;
8

            
9
use super::{
10
    super::device::{
11
        Cursor, Device, DeviceModel, ListOptions, ListQueryCond, QueryCond, SortKey,
12
        UpdateQueryCond, Updates,
13
    },
14
    build_where_like,
15
};
16

            
17
/// Model instance.
18
pub struct Model {
19
    /// The associated database connection.
20
    conn: Arc<SqlitePool>,
21
}
22

            
23
/// Cursor for paging through [`Model`] list results.
///
/// The SQLite implementation keeps only a progress offset: the number of rows already
/// handed out by previous `list` calls. `list` adds it to the requested offset and
/// subtracts it from the requested limit when resuming.
pub struct DbCursor {
    /// Number of rows consumed so far.
    offset: u64,
}
29

            
30
/// SQLite schema.
31
#[derive(sqlx::FromRow)]
32
struct Schema {
33
    device_id: String,
34
    unit_id: String,
35
    /// use empty string as NULL because duplicate
36
    /// `(unit_code=NULL,network_code="code",network_addr="addr")` is allowed.
37
    unit_code: String,
38
    network_id: String,
39
    network_code: String,
40
    network_addr: String,
41
    /// i64 as time tick from Epoch in milliseconds.
42
    created_at: i64,
43
    /// i64 as time tick from Epoch in milliseconds.
44
    modified_at: i64,
45
    profile: String,
46
    name: String,
47
    info: String,
48
}
49

            
50
/// Use "COUNT(*)" instead of "COUNT(fields...)" to simplify the implementation.
51
#[derive(sqlx::FromRow)]
52
struct CountSchema {
53
    #[sqlx(rename = "COUNT(*)")]
54
    count: i64,
55
}
56

            
57
/// Name of the SQLite table holding devices.
const TABLE_NAME: &str = "device";
/// Columns used by `SELECT` and `INSERT` statements.
///
/// The order matters: insert rows list their values in exactly this order.
const FIELDS: &[&str] = &[
    "device_id",
    "unit_id",
    "unit_code",
    "network_id",
    "network_code",
    "network_addr",
    "created_at",
    "modified_at",
    "profile",
    "name",
    "info",
];
/// Table creation statement; `IF NOT EXISTS` makes it safe to run on every start-up.
///
/// `unit_code` is `NOT NULL` and stores the empty string for "no unit code", so the
/// `UNIQUE (unit_code,network_code,network_addr)` constraint is actually enforced
/// (SQLite treats NULLs as distinct in unique constraints).
const TABLE_INIT_SQL: &str = "\
    CREATE TABLE IF NOT EXISTS device (\
    device_id TEXT NOT NULL UNIQUE,\
    unit_id TEXT NOT NULL,\
    unit_code TEXT NOT NULL,\
    network_id TEXT NOT NULL,\
    network_code TEXT NOT NULL,\
    network_addr TEXT NOT NULL,\
    created_at INTEGER NOT NULL,\
    modified_at INTEGER NOT NULL,\
    profile TEXT NOT NULL,\
    name TEXT NOT NULL,\
    info TEXT,\
    UNIQUE (unit_code,network_code,network_addr),\
    PRIMARY KEY (device_id))";
86

            
87
impl Model {
88
    /// To create the model instance with a database connection.
89
28
    pub async fn new(conn: Arc<SqlitePool>) -> Result<Self, Box<dyn StdError>> {
90
28
        let model = Model { conn };
91
28
        model.init().await?;
92
28
        Ok(model)
93
28
    }
94
}
95

            
96
#[async_trait]
97
impl DeviceModel for Model {
98
48
    async fn init(&self) -> Result<(), Box<dyn StdError>> {
99
48
        let _ = sqlx::query(TABLE_INIT_SQL)
100
48
            .execute(self.conn.as_ref())
101
48
            .await?;
102
48
        Ok(())
103
96
    }
104

            
105
79
    async fn count(&self, cond: &ListQueryCond) -> Result<u64, Box<dyn StdError>> {
106
79
        let sql = build_list_where(SqlBuilder::select_from(TABLE_NAME).count("*"), &cond).sql()?;
107

            
108
79
        let result: Result<CountSchema, sqlx::Error> = sqlx::query_as(sql.as_str())
109
79
            .fetch_one(self.conn.as_ref())
110
79
            .await;
111

            
112
79
        let row = match result {
113
            Err(e) => return Err(Box::new(e)),
114
79
            Ok(row) => row,
115
79
        };
116
79
        Ok(row.count as u64)
117
158
    }
118

            
119
    async fn list(
120
        &self,
121
        opts: &ListOptions,
122
        cursor: Option<Box<dyn Cursor>>,
123
369
    ) -> Result<(Vec<Device>, Option<Box<dyn Cursor>>), Box<dyn StdError>> {
124
369
        let mut cursor = match cursor {
125
327
            None => Box::new(DbCursor::new()),
126
42
            Some(cursor) => cursor,
127
        };
128

            
129
369
        let mut opts = ListOptions { ..*opts };
130
369
        if let Some(offset) = opts.offset {
131
24
            opts.offset = Some(offset + cursor.offset());
132
345
        } else {
133
345
            opts.offset = Some(cursor.offset());
134
345
        }
135
369
        let opts_limit = opts.limit;
136
369
        if let Some(limit) = opts_limit {
137
142
            if limit > 0 {
138
141
                if cursor.offset() >= limit {
139
13
                    return Ok((vec![], None));
140
128
                }
141
128
                opts.limit = Some(limit - cursor.offset());
142
1
            }
143
227
        }
144
356
        let mut builder = SqlBuilder::select_from(TABLE_NAME);
145
356
        build_limit_offset(&mut builder, &opts);
146
356
        build_sort(&mut builder, &opts);
147
356
        let sql = build_list_where(&mut builder, opts.cond).sql()?;
148

            
149
356
        let mut rows = sqlx::query_as::<_, Schema>(sql.as_str()).fetch(self.conn.as_ref());
150
356

            
151
356
        let mut count: u64 = 0;
152
356
        let mut list = vec![];
153
74666
        while let Some(row) = rows.try_next().await? {
154
74362
            let _ = cursor.as_mut().try_next().await?;
155
74362
            list.push(Device {
156
74362
                device_id: row.device_id,
157
74362
                unit_id: row.unit_id,
158
74362
                unit_code: match row.unit_code.len() {
159
21760
                    0 => None,
160
52602
                    _ => Some(row.unit_code),
161
                },
162
74362
                network_id: row.network_id,
163
74362
                network_code: row.network_code,
164
74362
                network_addr: row.network_addr,
165
74362
                created_at: Utc.timestamp_nanos(row.created_at * 1000000),
166
74362
                modified_at: Utc.timestamp_nanos(row.modified_at * 1000000),
167
74362
                profile: row.profile,
168
74362
                name: row.name,
169
74362
                info: serde_json::from_str(row.info.as_str())?,
170
            });
171
74362
            if let Some(limit) = opts_limit {
172
2037
                if limit > 0 && cursor.offset() >= limit {
173
23
                    if let Some(cursor_max) = opts.cursor_max {
174
22
                        if (count + 1) >= cursor_max {
175
13
                            return Ok((list, Some(cursor)));
176
9
                        }
177
1
                    }
178
10
                    return Ok((list, None));
179
2014
                }
180
72325
            }
181
74339
            if let Some(cursor_max) = opts.cursor_max {
182
4421
                count += 1;
183
4421
                if count >= cursor_max {
184
29
                    return Ok((list, Some(cursor)));
185
4392
                }
186
69918
            }
187
        }
188
304
        Ok((list, None))
189
738
    }
190

            
191
491
    async fn get(&self, cond: &QueryCond) -> Result<Option<Device>, Box<dyn StdError>> {
192
491
        let sql = build_where(SqlBuilder::select_from(TABLE_NAME).fields(FIELDS), &cond).sql()?;
193

            
194
491
        let result: Result<Schema, sqlx::Error> = sqlx::query_as(sql.as_str())
195
491
            .fetch_one(self.conn.as_ref())
196
491
            .await;
197

            
198
491
        let row = match result {
199
78
            Err(e) => match e {
200
78
                sqlx::Error::RowNotFound => return Ok(None),
201
                _ => return Err(Box::new(e)),
202
            },
203
413
            Ok(row) => row,
204
413
        };
205
413

            
206
413
        Ok(Some(Device {
207
413
            device_id: row.device_id,
208
413
            unit_id: row.unit_id,
209
413
            unit_code: match row.unit_code.len() {
210
194
                0 => None,
211
219
                _ => Some(row.unit_code),
212
            },
213
413
            network_id: row.network_id,
214
413
            network_code: row.network_code,
215
413
            network_addr: row.network_addr,
216
413
            created_at: Utc.timestamp_nanos(row.created_at * 1000000),
217
413
            modified_at: Utc.timestamp_nanos(row.modified_at * 1000000),
218
413
            profile: row.profile,
219
413
            name: row.name,
220
413
            info: serde_json::from_str(row.info.as_str())?,
221
        }))
222
982
    }
223

            
224
2496
    async fn add(&self, device: &Device) -> Result<(), Box<dyn StdError>> {
225
2496
        let unit_code = match device.unit_code.as_deref() {
226
1822
            None => quote(""),
227
674
            Some(value) => quote(value),
228
        };
229
2496
        let info = match serde_json::to_string(&device.info) {
230
            Err(_) => quote("{}"),
231
2496
            Ok(value) => quote(value.as_str()),
232
        };
233
2496
        let values = vec![
234
2496
            quote(device.device_id.as_str()),
235
2496
            quote(device.unit_id.as_str()),
236
2496
            unit_code,
237
2496
            quote(device.network_id.as_str()),
238
2496
            quote(device.network_code.as_str()),
239
2496
            quote(device.network_addr.as_str()),
240
2496
            device.created_at.timestamp_millis().to_string(),
241
2496
            device.modified_at.timestamp_millis().to_string(),
242
2496
            quote(device.profile.as_str()),
243
2496
            quote(device.name.as_str()),
244
2496
            info,
245
2496
        ];
246
2496
        let sql = SqlBuilder::insert_into(TABLE_NAME)
247
2496
            .fields(FIELDS)
248
2496
            .values(&values)
249
2496
            .sql()?;
250
2496
        let _ = sqlx::query(sql.as_str())
251
2496
            .execute(self.conn.as_ref())
252
2496
            .await?;
253
2494
        Ok(())
254
4992
    }
255

            
256
71
    async fn add_bulk(&self, devices: &Vec<Device>) -> Result<(), Box<dyn StdError>> {
257
71
        let mut builder = SqlBuilder::insert_into(TABLE_NAME);
258
71
        builder.fields(FIELDS);
259

            
260
41305
        for device in devices.iter() {
261
41305
            let unit_code = match device.unit_code.as_deref() {
262
12592
                None => quote(""),
263
28713
                Some(value) => quote(value),
264
            };
265
41305
            let info = match serde_json::to_string(&device.info) {
266
                Err(_) => quote("{}"),
267
41305
                Ok(value) => quote(value.as_str()),
268
            };
269

            
270
41305
            builder.values(&vec![
271
41305
                quote(device.device_id.as_str()),
272
41305
                quote(device.unit_id.as_str()),
273
41305
                unit_code,
274
41305
                quote(device.network_id.as_str()),
275
41305
                quote(device.network_code.as_str()),
276
41305
                quote(device.network_addr.as_str()),
277
41305
                device.created_at.timestamp_millis().to_string(),
278
41305
                device.modified_at.timestamp_millis().to_string(),
279
41305
                quote(device.profile.as_str()),
280
41305
                quote(device.name.as_str()),
281
41305
                info,
282
41305
            ]);
283
        }
284
71
        let sql = builder.sql()?.replace(");", ") ON CONFLICT DO NOTHING;");
285
71
        let _ = sqlx::query(sql.as_str())
286
71
            .execute(self.conn.as_ref())
287
71
            .await?;
288
71
        Ok(())
289
142
    }
290

            
291
106
    async fn del(&self, cond: &QueryCond) -> Result<(), Box<dyn StdError>> {
292
106
        let sql = build_where(&mut SqlBuilder::delete_from(TABLE_NAME), cond).sql()?;
293
106
        let _ = sqlx::query(sql.as_str())
294
106
            .execute(self.conn.as_ref())
295
106
            .await?;
296
106
        Ok(())
297
212
    }
298

            
299
    async fn update(
300
        &self,
301
        cond: &UpdateQueryCond,
302
        updates: &Updates,
303
33
    ) -> Result<(), Box<dyn StdError>> {
304
33
        let sql = match build_update_where(&mut SqlBuilder::update_table(TABLE_NAME), cond, updates)
305
        {
306
1
            None => return Ok(()),
307
32
            Some(builder) => builder.sql()?,
308
        };
309
32
        let _ = sqlx::query(sql.as_str())
310
32
            .execute(self.conn.as_ref())
311
32
            .await?;
312
32
        Ok(())
313
66
    }
314
}
315

            
316
impl DbCursor {
317
    /// To create the cursor instance.
318
327
    pub fn new() -> Self {
319
327
        DbCursor { offset: 0 }
320
327
    }
321
}
322

            
323
#[async_trait]
324
impl Cursor for DbCursor {
325
74362
    async fn try_next(&mut self) -> Result<Option<Device>, Box<dyn StdError>> {
326
74362
        self.offset += 1;
327
74362
        Ok(None)
328
148724
    }
329

            
330
2675
    fn offset(&self) -> u64 {
331
2675
        self.offset
332
2675
    }
333
}
334

            
335
/// Transforms query conditions to the SQL builder.
336
597
fn build_where<'a>(builder: &'a mut SqlBuilder, cond: &QueryCond<'a>) -> &'a mut SqlBuilder {
337
597
    if let Some(value) = cond.unit_id {
338
63
        builder.and_where_eq("unit_id", quote(value));
339
534
    }
340
597
    if let Some(value) = cond.device_id {
341
480
        builder.and_where_eq("device_id", quote(value));
342
480
    }
343
597
    if let Some(value) = cond.network_id {
344
35
        builder.and_where_eq("network_id", quote(value));
345
562
    }
346
597
    if let Some(value) = cond.network_addrs {
347
16496
        let values: Vec<String> = value.iter().map(|&x| quote(x)).collect();
348
22
        builder.and_where_in("network_addr", &values);
349
575
    }
350
597
    if let Some(value) = cond.device.as_ref() {
351
59
        if let Some(unit_code) = value.unit_code {
352
44
            builder.and_where_eq("unit_code", quote(unit_code));
353
44
        } else {
354
15
            builder.and_where_eq("unit_code", quote(""));
355
15
        }
356
59
        builder.and_where_eq("network_code", quote(value.network_code));
357
59
        builder.and_where_eq("network_addr", quote(value.network_addr));
358
538
    }
359
597
    builder
360
597
}
361

            
362
/// Transforms query conditions to the SQL builder.
363
435
fn build_list_where<'a>(
364
435
    builder: &'a mut SqlBuilder,
365
435
    cond: &ListQueryCond<'a>,
366
435
) -> &'a mut SqlBuilder {
367
435
    if let Some(value) = cond.unit_id {
368
172
        builder.and_where_eq("unit_id", quote(value));
369
263
    }
370
435
    if let Some(value) = cond.device_id {
371
10
        builder.and_where_eq("device_id", quote(value));
372
425
    }
373
435
    if let Some(value) = cond.network_id {
374
162
        builder.and_where_eq("network_id", quote(value));
375
273
    }
376
435
    if let Some(value) = cond.network_code {
377
46
        builder.and_where_eq("network_code", quote(value));
378
389
    }
379
435
    if let Some(value) = cond.network_addr {
380
102
        builder.and_where_eq("network_addr", quote(value));
381
333
    } else if let Some(value) = cond.network_addrs {
382
36890
        let values: Vec<String> = value.iter().map(|&x| quote(x)).collect();
383
54
        builder.and_where_in("network_addr", &values);
384
279
    }
385
435
    if let Some(value) = cond.profile {
386
14
        build_where_like(builder, "profile", value.to_lowercase().as_str());
387
421
    }
388
435
    if let Some(value) = cond.name_contains {
389
72
        build_where_like(builder, "name", value.to_lowercase().as_str());
390
363
    }
391
435
    builder
392
435
}
393

            
394
/// Transforms model options to the SQL builder.
395
356
fn build_limit_offset<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
396
356
    if let Some(value) = opts.limit {
397
129
        if value > 0 {
398
128
            builder.limit(value);
399
128
        }
400
227
    }
401
356
    if let Some(value) = opts.offset {
402
356
        match opts.limit {
403
227
            None => builder.limit(-1).offset(value),
404
1
            Some(0) => builder.limit(-1).offset(value),
405
128
            _ => builder.offset(value),
406
        };
407
    }
408
356
    builder
409
356
}
410

            
411
/// Transforms model options to the SQL builder.
412
356
fn build_sort<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
413
356
    if let Some(sort_cond) = opts.sort.as_ref() {
414
362
        for cond in sort_cond.iter() {
415
362
            let key = match cond.key {
416
8
                SortKey::CreatedAt => "created_at",
417
6
                SortKey::ModifiedAt => "modified_at",
418
140
                SortKey::NetworkCode => "network_code",
419
198
                SortKey::NetworkAddr => "network_addr",
420
2
                SortKey::Profile => "profile",
421
8
                SortKey::Name => "name",
422
            };
423
362
            builder.order_by(key, !cond.asc);
424
        }
425
135
    }
426
356
    builder
427
356
}
428

            
429
/// Transforms query conditions and the model object to the SQL builder.
430
33
fn build_update_where<'a>(
431
33
    builder: &'a mut SqlBuilder,
432
33
    cond: &UpdateQueryCond<'a>,
433
33
    updates: &Updates,
434
33
) -> Option<&'a mut SqlBuilder> {
435
33
    let mut count = 0;
436
33
    if let Some((network_id, network_code)) = updates.network {
437
10
        builder.set("network_id", quote(network_id));
438
10
        builder.set("network_code", quote(network_code));
439
10
        count += 1;
440
23
    }
441
33
    if let Some(value) = updates.network_addr {
442
14
        builder.set("network_addr", quote(value));
443
14
        count += 1;
444
19
    }
445
33
    if let Some(value) = updates.modified_at.as_ref() {
446
32
        builder.set("modified_at", value.timestamp_millis());
447
32
        count += 1;
448
32
    }
449
33
    if let Some(value) = updates.profile.as_ref() {
450
20
        builder.set("profile", quote(value));
451
20
        count += 1;
452
20
    }
453
33
    if let Some(value) = updates.name.as_ref() {
454
20
        builder.set("name", quote(value));
455
20
        count += 1;
456
20
    }
457
33
    if let Some(value) = updates.info {
458
18
        match serde_json::to_string(value) {
459
            Err(_) => {
460
                builder.set("info", quote("{}"));
461
            }
462
18
            Ok(value) => {
463
18
                builder.set("info", quote(value));
464
18
            }
465
        }
466
18
        count += 1;
467
15
    }
468
33
    if count == 0 {
469
1
        return None;
470
32
    }
471
32

            
472
32
    builder.and_where_eq("device_id", quote(cond.device_id));
473
32
    Some(builder)
474
33
}