1
use std::{error::Error as StdError, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeZone, Utc};
5
use futures::TryStreamExt;
6
use sql_builder::{quote, SqlBuilder};
7
use sqlx::SqlitePool;
8

            
9
use super::super::device_route::{
10
    Cursor, DeviceRoute, DeviceRouteModel, ListOptions, ListQueryCond, QueryCond, SortKey,
11
    UpdateQueryCond, Updates,
12
};
13

            
14
/// Model instance.
15
pub struct Model {
16
    /// The associated database connection.
17
    conn: Arc<SqlitePool>,
18
}
19

            
20
/// Pagination cursor of the SQLite model.
///
/// The cursor only remembers how many rows have been consumed so far; `list` adds this
/// progress to the caller's original list options to resume the query.
pub struct DbCursor {
    /// Number of rows already consumed through this cursor.
    offset: u64,
}
26

            
27
/// SQLite schema.
28
#[derive(sqlx::FromRow)]
29
struct Schema {
30
    route_id: String,
31
    unit_id: String,
32
    unit_code: String,
33
    application_id: String,
34
    application_code: String,
35
    device_id: String,
36
    network_id: String,
37
    network_code: String,
38
    network_addr: String,
39
    profile: String,
40
    /// i64 as time tick from Epoch in milliseconds.
41
    created_at: i64,
42
    /// i64 as time tick from Epoch in milliseconds.
43
    modified_at: i64,
44
}
45

            
46
/// Use "COUNT(*)" instead of "COUNT(fields...)" to simplify the implementation.
47
#[derive(sqlx::FromRow)]
48
struct CountSchema {
49
    #[sqlx(rename = "COUNT(*)")]
50
    count: i64,
51
}
52

            
53
/// Name of the SQLite table that stores device routes.
const TABLE_NAME: &str = "device_route";
/// Column names in the order used for inserts and explicit selects; matches [`Schema`].
const FIELDS: &[&str] = &[
    "route_id",
    "unit_id",
    "unit_code",
    "application_id",
    "application_code",
    "device_id",
    "network_id",
    "network_code",
    "network_addr",
    "profile",
    "created_at",
    "modified_at",
];
/// Creates the `device_route` table when it does not exist yet.
///
/// Routes are unique by `route_id` and by the `(application_id, device_id)` pair. Timestamps
/// are stored as `INTEGER` epoch milliseconds. `route_id` is declared both `UNIQUE` and
/// `PRIMARY KEY`; it is kept that way so existing databases keep an identical schema.
const TABLE_INIT_SQL: &str = "\
    CREATE TABLE IF NOT EXISTS device_route (\
    route_id TEXT NOT NULL UNIQUE,\
    unit_id TEXT NOT NULL,\
    unit_code TEXT NOT NULL,\
    application_id TEXT NOT NULL,\
    application_code TEXT NOT NULL,\
    device_id TEXT NOT NULL,\
    network_id TEXT NOT NULL,\
    network_code TEXT NOT NULL,\
    network_addr TEXT NOT NULL,\
    profile TEXT NOT NULL,\
    created_at INTEGER NOT NULL,\
    modified_at INTEGER NOT NULL,\
    UNIQUE (application_id,device_id),\
    PRIMARY KEY (route_id))";
84

            
85
impl Model {
86
    /// To create the model instance with a database connection.
87
56
    pub async fn new(conn: Arc<SqlitePool>) -> Result<Self, Box<dyn StdError>> {
88
56
        let model = Model { conn };
89
56
        model.init().await?;
90
56
        Ok(model)
91
56
    }
92
}
93

            
94
#[async_trait]
95
impl DeviceRouteModel for Model {
96
96
    async fn init(&self) -> Result<(), Box<dyn StdError>> {
97
96
        let _ = sqlx::query(TABLE_INIT_SQL)
98
96
            .execute(self.conn.as_ref())
99
96
            .await?;
100
96
        Ok(())
101
192
    }
102

            
103
220
    async fn count(&self, cond: &ListQueryCond) -> Result<u64, Box<dyn StdError>> {
104
220
        let sql = build_list_where(SqlBuilder::select_from(TABLE_NAME).count("*"), &cond).sql()?;
105

            
106
220
        let result: Result<CountSchema, sqlx::Error> = sqlx::query_as(sql.as_str())
107
220
            .fetch_one(self.conn.as_ref())
108
220
            .await;
109

            
110
220
        let row = match result {
111
            Err(e) => return Err(Box::new(e)),
112
220
            Ok(row) => row,
113
220
        };
114
220
        Ok(row.count as u64)
115
440
    }
116

            
117
    async fn list(
118
        &self,
119
        opts: &ListOptions,
120
        cursor: Option<Box<dyn Cursor>>,
121
678
    ) -> Result<(Vec<DeviceRoute>, Option<Box<dyn Cursor>>), Box<dyn StdError>> {
122
678
        let mut cursor = match cursor {
123
610
            None => Box::new(DbCursor::new()),
124
68
            Some(cursor) => cursor,
125
        };
126

            
127
678
        let mut opts = ListOptions { ..*opts };
128
678
        if let Some(offset) = opts.offset {
129
48
            opts.offset = Some(offset + cursor.offset());
130
630
        } else {
131
630
            opts.offset = Some(cursor.offset());
132
630
        }
133
678
        let opts_limit = opts.limit;
134
678
        if let Some(limit) = opts_limit {
135
400
            if limit > 0 {
136
398
                if cursor.offset() >= limit {
137
34
                    return Ok((vec![], None));
138
364
                }
139
364
                opts.limit = Some(limit - cursor.offset());
140
2
            }
141
278
        }
142
644
        let mut builder = SqlBuilder::select_from(TABLE_NAME);
143
644
        build_limit_offset(&mut builder, &opts);
144
644
        build_sort(&mut builder, &opts);
145
644
        let sql = build_list_where(&mut builder, opts.cond).sql()?;
146

            
147
644
        let mut rows = sqlx::query_as::<_, Schema>(sql.as_str()).fetch(self.conn.as_ref());
148
644

            
149
644
        let mut count: u64 = 0;
150
644
        let mut list = vec![];
151
73856
        while let Some(row) = rows.try_next().await? {
152
73308
            let _ = cursor.as_mut().try_next().await?;
153
73308
            list.push(DeviceRoute {
154
73308
                route_id: row.route_id,
155
73308
                unit_id: row.unit_id,
156
73308
                unit_code: row.unit_code,
157
73308
                application_id: row.application_id,
158
73308
                application_code: row.application_code,
159
73308
                device_id: row.device_id,
160
73308
                network_id: row.network_id,
161
73308
                network_code: row.network_code,
162
73308
                network_addr: row.network_addr,
163
73308
                profile: row.profile,
164
73308
                created_at: Utc.timestamp_nanos(row.created_at * 1000000),
165
73308
                modified_at: Utc.timestamp_nanos(row.modified_at * 1000000),
166
73308
            });
167
73308
            if let Some(limit) = opts_limit {
168
4758
                if limit > 0 && cursor.offset() >= limit {
169
62
                    if let Some(cursor_max) = opts.cursor_max {
170
52
                        if (count + 1) >= cursor_max {
171
34
                            return Ok((list, Some(cursor)));
172
18
                        }
173
10
                    }
174
28
                    return Ok((list, None));
175
4696
                }
176
68550
            }
177
73246
            if let Some(cursor_max) = opts.cursor_max {
178
7108
                count += 1;
179
7108
                if count >= cursor_max {
180
34
                    return Ok((list, Some(cursor)));
181
7074
                }
182
66138
            }
183
        }
184
548
        Ok((list, None))
185
1356
    }
186

            
187
406
    async fn get(&self, route_id: &str) -> Result<Option<DeviceRoute>, Box<dyn StdError>> {
188
406
        let sql = SqlBuilder::select_from(TABLE_NAME)
189
406
            .fields(FIELDS)
190
406
            .and_where_eq("route_id", quote(route_id))
191
406
            .sql()?;
192

            
193
406
        let result: Result<Schema, sqlx::Error> = sqlx::query_as(sql.as_str())
194
406
            .fetch_one(self.conn.as_ref())
195
406
            .await;
196

            
197
406
        let row = match result {
198
126
            Err(e) => match e {
199
126
                sqlx::Error::RowNotFound => return Ok(None),
200
                _ => return Err(Box::new(e)),
201
            },
202
280
            Ok(row) => row,
203
280
        };
204
280

            
205
280
        Ok(Some(DeviceRoute {
206
280
            route_id: row.route_id,
207
280
            unit_id: row.unit_id,
208
280
            unit_code: row.unit_code,
209
280
            application_id: row.application_id,
210
280
            application_code: row.application_code,
211
280
            device_id: row.device_id,
212
280
            network_id: row.network_id,
213
280
            network_code: row.network_code,
214
280
            network_addr: row.network_addr,
215
280
            profile: row.profile,
216
280
            created_at: Utc.timestamp_nanos(row.created_at * 1000000),
217
280
            modified_at: Utc.timestamp_nanos(row.modified_at * 1000000),
218
280
        }))
219
812
    }
220

            
221
2156
    async fn add(&self, route: &DeviceRoute) -> Result<(), Box<dyn StdError>> {
222
2156
        let values = vec![
223
2156
            quote(route.route_id.as_str()),
224
2156
            quote(route.unit_id.as_str()),
225
2156
            quote(route.unit_code.as_str()),
226
2156
            quote(route.application_id.as_str()),
227
2156
            quote(route.application_code.as_str()),
228
2156
            quote(route.device_id.as_str()),
229
2156
            quote(route.network_id.as_str()),
230
2156
            quote(route.network_code.as_str()),
231
2156
            quote(route.network_addr.as_str()),
232
2156
            quote(route.profile.as_str()),
233
2156
            route.created_at.timestamp_millis().to_string(),
234
2156
            route.modified_at.timestamp_millis().to_string(),
235
2156
        ];
236
2156
        let sql = SqlBuilder::insert_into(TABLE_NAME)
237
2156
            .fields(FIELDS)
238
2156
            .values(&values)
239
2156
            .sql()?;
240
2156
        let _ = sqlx::query(sql.as_str())
241
2156
            .execute(self.conn.as_ref())
242
2156
            .await?;
243
2152
        Ok(())
244
4312
    }
245

            
246
70
    async fn add_bulk(&self, routes: &Vec<DeviceRoute>) -> Result<(), Box<dyn StdError>> {
247
70
        let mut builder = SqlBuilder::insert_into(TABLE_NAME);
248
70
        builder.fields(FIELDS);
249

            
250
57954
        for route in routes.iter() {
251
57954
            builder.values(&vec![
252
57954
                quote(route.route_id.as_str()),
253
57954
                quote(route.unit_id.as_str()),
254
57954
                quote(route.unit_code.as_str()),
255
57954
                quote(route.application_id.as_str()),
256
57954
                quote(route.application_code.as_str()),
257
57954
                quote(route.device_id.as_str()),
258
57954
                quote(route.network_id.as_str()),
259
57954
                quote(route.network_code.as_str()),
260
57954
                quote(route.network_addr.as_str()),
261
57954
                quote(route.profile.as_str()),
262
57954
                route.created_at.timestamp_millis().to_string(),
263
57954
                route.modified_at.timestamp_millis().to_string(),
264
57954
            ]);
265
57954
        }
266
70
        let sql = builder.sql()?.replace(");", ") ON CONFLICT DO NOTHING;");
267
70
        let _ = sqlx::query(sql.as_str())
268
70
            .execute(self.conn.as_ref())
269
70
            .await?;
270
70
        Ok(())
271
140
    }
272

            
273
240
    async fn del(&self, cond: &QueryCond) -> Result<(), Box<dyn StdError>> {
274
240
        let sql = build_where(&mut SqlBuilder::delete_from(TABLE_NAME), cond).sql()?;
275
240
        let _ = sqlx::query(sql.as_str())
276
240
            .execute(self.conn.as_ref())
277
240
            .await?;
278
240
        Ok(())
279
480
    }
280

            
281
    async fn update(
282
        &self,
283
        cond: &UpdateQueryCond,
284
        updates: &Updates,
285
42
    ) -> Result<(), Box<dyn StdError>> {
286
42
        let sql = match build_update_where(&mut SqlBuilder::update_table(TABLE_NAME), cond, updates)
287
        {
288
2
            None => return Ok(()),
289
40
            Some(builder) => builder.sql()?,
290
        };
291
40
        let _ = sqlx::query(sql.as_str())
292
40
            .execute(self.conn.as_ref())
293
40
            .await?;
294
40
        Ok(())
295
84
    }
296
}
297

            
298
impl DbCursor {
299
    /// To create the cursor instance.
300
610
    pub fn new() -> Self {
301
610
        DbCursor { offset: 0 }
302
610
    }
303
}
304

            
305
#[async_trait]
306
impl Cursor for DbCursor {
307
73308
    async fn try_next(&mut self) -> Result<Option<DeviceRoute>, Box<dyn StdError>> {
308
73308
        self.offset += 1;
309
73308
        Ok(None)
310
146616
    }
311

            
312
6198
    fn offset(&self) -> u64 {
313
6198
        self.offset
314
6198
    }
315
}
316

            
317
/// Transforms query conditions to the SQL builder.
318
240
fn build_where<'a>(builder: &'a mut SqlBuilder, cond: &QueryCond<'a>) -> &'a mut SqlBuilder {
319
240
    if let Some(value) = cond.route_id {
320
16
        builder.and_where_eq("route_id", quote(value));
321
224
    }
322
240
    if let Some(value) = cond.unit_id {
323
72
        builder.and_where_eq("unit_id", quote(value));
324
168
    }
325
240
    if let Some(value) = cond.application_id {
326
38
        builder.and_where_eq("application_id", quote(value));
327
202
    }
328
240
    if let Some(value) = cond.network_id {
329
66
        builder.and_where_eq("network_id", quote(value));
330
174
    }
331
240
    if let Some(value) = cond.device_id {
332
90
        builder.and_where_eq("device_id", quote(value));
333
150
    }
334
240
    if let Some(value) = cond.network_addrs {
335
32976
        let values: Vec<String> = value.iter().map(|&x| quote(x)).collect();
336
44
        builder.and_where_in("network_addr", &values);
337
196
    }
338
240
    builder
339
240
}
340

            
341
/// Transforms query conditions to the SQL builder.
342
864
fn build_list_where<'a>(
343
864
    builder: &'a mut SqlBuilder,
344
864
    cond: &ListQueryCond<'a>,
345
864
) -> &'a mut SqlBuilder {
346
864
    if let Some(value) = cond.route_id {
347
68
        builder.and_where_eq("route_id", quote(value));
348
796
    }
349
864
    if let Some(value) = cond.unit_id {
350
260
        builder.and_where_eq("unit_id", quote(value));
351
604
    }
352
864
    if let Some(value) = cond.unit_code {
353
12
        builder.and_where_eq("unit_code", quote(value));
354
852
    }
355
864
    if let Some(value) = cond.application_id {
356
288
        builder.and_where_eq("application_id", quote(value));
357
576
    }
358
864
    if let Some(value) = cond.application_code {
359
12
        builder.and_where_eq("application_code", quote(value));
360
852
    }
361
864
    if let Some(value) = cond.network_id {
362
216
        builder.and_where_eq("network_id", quote(value));
363
648
    }
364
864
    if let Some(value) = cond.network_code {
365
12
        builder.and_where_eq("network_code", quote(value));
366
852
    }
367
864
    if let Some(value) = cond.network_addr {
368
12
        builder.and_where_eq("network_addr", quote(value));
369
852
    }
370
864
    if let Some(value) = cond.network_addrs {
371
24
        let values: Vec<String> = value.iter().map(|&x| quote(x)).collect();
372
8
        builder.and_where_in("network_addr", &values);
373
856
    }
374
864
    if let Some(value) = cond.device_id {
375
210
        builder.and_where_eq("device_id", quote(value));
376
654
    }
377
864
    builder
378
864
}
379

            
380
/// Transforms model options to the SQL builder.
381
644
fn build_limit_offset<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
382
644
    if let Some(value) = opts.limit {
383
366
        if value > 0 {
384
364
            builder.limit(value);
385
364
        }
386
278
    }
387
644
    if let Some(value) = opts.offset {
388
644
        match opts.limit {
389
278
            None => builder.limit(-1).offset(value),
390
2
            Some(0) => builder.limit(-1).offset(value),
391
364
            _ => builder.offset(value),
392
        };
393
    }
394
644
    builder
395
644
}
396

            
397
/// Transforms model options to the SQL builder.
398
644
fn build_sort<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
399
644
    if let Some(sort_cond) = opts.sort.as_ref() {
400
1020
        for cond in sort_cond.iter() {
401
1020
            let key = match cond.key {
402
304
                SortKey::CreatedAt => "created_at",
403
12
                SortKey::ModifiedAt => "modified_at",
404
12
                SortKey::ApplicationCode => "application_code",
405
288
                SortKey::NetworkCode => "network_code",
406
404
                SortKey::NetworkAddr => "network_addr",
407
            };
408
1020
            builder.order_by(key, !cond.asc);
409
        }
410
206
    }
411
644
    builder
412
644
}
413

            
414
/// Transforms query conditions and the model object to the SQL builder.
415
42
fn build_update_where<'a>(
416
42
    builder: &'a mut SqlBuilder,
417
42
    cond: &UpdateQueryCond<'a>,
418
42
    updates: &Updates,
419
42
) -> Option<&'a mut SqlBuilder> {
420
42
    let mut count = 0;
421
42
    if let Some(value) = updates.modified_at.as_ref() {
422
40
        builder.set("modified_at", value.timestamp_millis());
423
40
        count += 1;
424
40
    }
425
42
    if let Some(value) = updates.profile.as_ref() {
426
38
        builder.set("profile", quote(value));
427
38
        count += 1;
428
38
    }
429
42
    if count == 0 {
430
2
        return None;
431
40
    }
432
40

            
433
40
    builder.and_where_eq("device_id", quote(cond.device_id));
434
40
    Some(builder)
435
42
}