1
use std::{error::Error as StdError, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeZone, Utc};
5
use futures::TryStreamExt;
6
use sql_builder::{quote, SqlBuilder};
7
use sqlx::SqlitePool;
8

            
9
use super::super::device_route::{
10
    Cursor, DeviceRoute, DeviceRouteModel, ListOptions, ListQueryCond, QueryCond, SortKey,
11
    UpdateQueryCond, Updates,
12
};
13

            
14
/// Model instance.
15
pub struct Model {
16
    /// The associated database connection.
17
    conn: Arc<SqlitePool>,
18
}
19

            
20
/// Cursor used to page through `Model::list` results.
///
/// The SQLite implementation keeps no result set of its own. It only counts how many rows have
/// been consumed. `list()` re-runs the query with the caller's original list options, shifted
/// by that offset.
///
/// `Default` yields the same state as `DbCursor::new()` (offset `0`).
#[derive(Default)]
pub struct DbCursor {
    /// Number of rows consumed through this cursor so far.
    offset: u64,
}
26

            
27
/// SQLite schema.
28
#[derive(sqlx::FromRow)]
29
struct Schema {
30
    route_id: String,
31
    unit_id: String,
32
    unit_code: String,
33
    application_id: String,
34
    application_code: String,
35
    device_id: String,
36
    network_id: String,
37
    network_code: String,
38
    network_addr: String,
39
    profile: String,
40
    /// i64 as time tick from Epoch in milliseconds.
41
    created_at: i64,
42
    /// i64 as time tick from Epoch in milliseconds.
43
    modified_at: i64,
44
}
45

            
46
/// Use "COUNT(*)" instead of "COUNT(fields...)" to simplify the implementation.
47
#[derive(sqlx::FromRow)]
48
struct CountSchema {
49
    #[sqlx(rename = "COUNT(*)")]
50
    count: i64,
51
}
52

            
53
/// Name of the table backing the device route model.
const TABLE_NAME: &str = "device_route";
/// Column list shared by `SELECT` and `INSERT` statements.
///
/// Inserts supply their values positionally, so value lists must follow this exact order.
const FIELDS: &[&str] = &[
    "route_id",
    "unit_id",
    "unit_code",
    "application_id",
    "application_code",
    "device_id",
    "network_id",
    "network_code",
    "network_addr",
    "profile",
    "created_at",
    "modified_at",
];
/// DDL that creates the `device_route` table when it does not exist yet.
///
/// Timestamps are `INTEGER` columns holding milliseconds since the Unix epoch. A device may have
/// at most one route per application (`UNIQUE (application_id,device_id)`).
const TABLE_INIT_SQL: &str = "\
    CREATE TABLE IF NOT EXISTS device_route (\
    route_id TEXT NOT NULL UNIQUE,\
    unit_id TEXT NOT NULL,\
    unit_code TEXT NOT NULL,\
    application_id TEXT NOT NULL,\
    application_code TEXT NOT NULL,\
    device_id TEXT NOT NULL,\
    network_id TEXT NOT NULL,\
    network_code TEXT NOT NULL,\
    network_addr TEXT NOT NULL,\
    profile TEXT NOT NULL,\
    created_at INTEGER NOT NULL,\
    modified_at INTEGER NOT NULL,\
    UNIQUE (application_id,device_id),\
    PRIMARY KEY (route_id))";
84

            
85
impl Model {
86
    /// To create the model instance with a database connection.
87
28
    pub async fn new(conn: Arc<SqlitePool>) -> Result<Self, Box<dyn StdError>> {
88
28
        let model = Model { conn };
89
28
        model.init().await?;
90
28
        Ok(model)
91
28
    }
92
}
93

            
94
#[async_trait]
95
impl DeviceRouteModel for Model {
96
48
    async fn init(&self) -> Result<(), Box<dyn StdError>> {
97
48
        let _ = sqlx::query(TABLE_INIT_SQL)
98
48
            .execute(self.conn.as_ref())
99
48
            .await?;
100
48
        Ok(())
101
96
    }
102

            
103
110
    async fn count(&self, cond: &ListQueryCond) -> Result<u64, Box<dyn StdError>> {
104
110
        let sql = build_list_where(SqlBuilder::select_from(TABLE_NAME).count("*"), &cond).sql()?;
105

            
106
110
        let result: Result<CountSchema, sqlx::Error> = sqlx::query_as(sql.as_str())
107
110
            .fetch_one(self.conn.as_ref())
108
110
            .await;
109

            
110
110
        let row = match result {
111
            Err(e) => return Err(Box::new(e)),
112
110
            Ok(row) => row,
113
110
        };
114
110
        Ok(row.count as u64)
115
220
    }
116

            
117
    async fn list(
118
        &self,
119
        opts: &ListOptions,
120
        cursor: Option<Box<dyn Cursor>>,
121
339
    ) -> Result<(Vec<DeviceRoute>, Option<Box<dyn Cursor>>), Box<dyn StdError>> {
122
339
        let mut cursor = match cursor {
123
305
            None => Box::new(DbCursor::new()),
124
34
            Some(cursor) => cursor,
125
        };
126

            
127
339
        let mut opts = ListOptions { ..*opts };
128
339
        if let Some(offset) = opts.offset {
129
24
            opts.offset = Some(offset + cursor.offset());
130
315
        } else {
131
315
            opts.offset = Some(cursor.offset());
132
315
        }
133
339
        let opts_limit = opts.limit;
134
339
        if let Some(limit) = opts_limit {
135
200
            if limit > 0 {
136
199
                if cursor.offset() >= limit {
137
17
                    return Ok((vec![], None));
138
182
                }
139
182
                opts.limit = Some(limit - cursor.offset());
140
1
            }
141
139
        }
142
322
        let mut builder = SqlBuilder::select_from(TABLE_NAME);
143
322
        build_limit_offset(&mut builder, &opts);
144
322
        build_sort(&mut builder, &opts);
145
322
        let sql = build_list_where(&mut builder, opts.cond).sql()?;
146

            
147
322
        let mut rows = sqlx::query_as::<_, Schema>(sql.as_str()).fetch(self.conn.as_ref());
148
322

            
149
322
        let mut count: u64 = 0;
150
322
        let mut list = vec![];
151
36928
        while let Some(row) = rows.try_next().await? {
152
36654
            let _ = cursor.as_mut().try_next().await?;
153
36654
            list.push(DeviceRoute {
154
36654
                route_id: row.route_id,
155
36654
                unit_id: row.unit_id,
156
36654
                unit_code: row.unit_code,
157
36654
                application_id: row.application_id,
158
36654
                application_code: row.application_code,
159
36654
                device_id: row.device_id,
160
36654
                network_id: row.network_id,
161
36654
                network_code: row.network_code,
162
36654
                network_addr: row.network_addr,
163
36654
                profile: row.profile,
164
36654
                created_at: Utc.timestamp_nanos(row.created_at * 1000000),
165
36654
                modified_at: Utc.timestamp_nanos(row.modified_at * 1000000),
166
36654
            });
167
36654
            if let Some(limit) = opts_limit {
168
2379
                if limit > 0 && cursor.offset() >= limit {
169
31
                    if let Some(cursor_max) = opts.cursor_max {
170
26
                        if (count + 1) >= cursor_max {
171
17
                            return Ok((list, Some(cursor)));
172
9
                        }
173
5
                    }
174
14
                    return Ok((list, None));
175
2348
                }
176
34275
            }
177
36623
            if let Some(cursor_max) = opts.cursor_max {
178
3554
                count += 1;
179
3554
                if count >= cursor_max {
180
17
                    return Ok((list, Some(cursor)));
181
3537
                }
182
33069
            }
183
        }
184
274
        Ok((list, None))
185
678
    }
186

            
187
203
    async fn get(&self, route_id: &str) -> Result<Option<DeviceRoute>, Box<dyn StdError>> {
188
203
        let sql = SqlBuilder::select_from(TABLE_NAME)
189
203
            .fields(FIELDS)
190
203
            .and_where_eq("route_id", quote(route_id))
191
203
            .sql()?;
192

            
193
203
        let result: Result<Schema, sqlx::Error> = sqlx::query_as(sql.as_str())
194
203
            .fetch_one(self.conn.as_ref())
195
203
            .await;
196

            
197
203
        let row = match result {
198
63
            Err(e) => match e {
199
63
                sqlx::Error::RowNotFound => return Ok(None),
200
                _ => return Err(Box::new(e)),
201
            },
202
140
            Ok(row) => row,
203
140
        };
204
140

            
205
140
        Ok(Some(DeviceRoute {
206
140
            route_id: row.route_id,
207
140
            unit_id: row.unit_id,
208
140
            unit_code: row.unit_code,
209
140
            application_id: row.application_id,
210
140
            application_code: row.application_code,
211
140
            device_id: row.device_id,
212
140
            network_id: row.network_id,
213
140
            network_code: row.network_code,
214
140
            network_addr: row.network_addr,
215
140
            profile: row.profile,
216
140
            created_at: Utc.timestamp_nanos(row.created_at * 1000000),
217
140
            modified_at: Utc.timestamp_nanos(row.modified_at * 1000000),
218
140
        }))
219
406
    }
220

            
221
1078
    async fn add(&self, route: &DeviceRoute) -> Result<(), Box<dyn StdError>> {
222
1078
        let values = vec![
223
1078
            quote(route.route_id.as_str()),
224
1078
            quote(route.unit_id.as_str()),
225
1078
            quote(route.unit_code.as_str()),
226
1078
            quote(route.application_id.as_str()),
227
1078
            quote(route.application_code.as_str()),
228
1078
            quote(route.device_id.as_str()),
229
1078
            quote(route.network_id.as_str()),
230
1078
            quote(route.network_code.as_str()),
231
1078
            quote(route.network_addr.as_str()),
232
1078
            quote(route.profile.as_str()),
233
1078
            route.created_at.timestamp_millis().to_string(),
234
1078
            route.modified_at.timestamp_millis().to_string(),
235
1078
        ];
236
1078
        let sql = SqlBuilder::insert_into(TABLE_NAME)
237
1078
            .fields(FIELDS)
238
1078
            .values(&values)
239
1078
            .sql()?;
240
1078
        let _ = sqlx::query(sql.as_str())
241
1078
            .execute(self.conn.as_ref())
242
1078
            .await?;
243
1076
        Ok(())
244
2156
    }
245

            
246
35
    async fn add_bulk(&self, routes: &Vec<DeviceRoute>) -> Result<(), Box<dyn StdError>> {
247
35
        let mut builder = SqlBuilder::insert_into(TABLE_NAME);
248
35
        builder.fields(FIELDS);
249

            
250
28977
        for route in routes.iter() {
251
28977
            builder.values(&vec![
252
28977
                quote(route.route_id.as_str()),
253
28977
                quote(route.unit_id.as_str()),
254
28977
                quote(route.unit_code.as_str()),
255
28977
                quote(route.application_id.as_str()),
256
28977
                quote(route.application_code.as_str()),
257
28977
                quote(route.device_id.as_str()),
258
28977
                quote(route.network_id.as_str()),
259
28977
                quote(route.network_code.as_str()),
260
28977
                quote(route.network_addr.as_str()),
261
28977
                quote(route.profile.as_str()),
262
28977
                route.created_at.timestamp_millis().to_string(),
263
28977
                route.modified_at.timestamp_millis().to_string(),
264
28977
            ]);
265
28977
        }
266
35
        let sql = builder.sql()?.replace(");", ") ON CONFLICT DO NOTHING;");
267
35
        let _ = sqlx::query(sql.as_str())
268
35
            .execute(self.conn.as_ref())
269
35
            .await?;
270
35
        Ok(())
271
70
    }
272

            
273
120
    async fn del(&self, cond: &QueryCond) -> Result<(), Box<dyn StdError>> {
274
120
        let sql = build_where(&mut SqlBuilder::delete_from(TABLE_NAME), cond).sql()?;
275
120
        let _ = sqlx::query(sql.as_str())
276
120
            .execute(self.conn.as_ref())
277
120
            .await?;
278
120
        Ok(())
279
240
    }
280

            
281
    async fn update(
282
        &self,
283
        cond: &UpdateQueryCond,
284
        updates: &Updates,
285
21
    ) -> Result<(), Box<dyn StdError>> {
286
21
        let sql = match build_update_where(&mut SqlBuilder::update_table(TABLE_NAME), cond, updates)
287
        {
288
1
            None => return Ok(()),
289
20
            Some(builder) => builder.sql()?,
290
        };
291
20
        let _ = sqlx::query(sql.as_str())
292
20
            .execute(self.conn.as_ref())
293
20
            .await?;
294
20
        Ok(())
295
42
    }
296
}
297

            
298
impl DbCursor {
299
    /// To create the cursor instance.
300
305
    pub fn new() -> Self {
301
305
        DbCursor { offset: 0 }
302
305
    }
303
}
304

            
305
#[async_trait]
306
impl Cursor for DbCursor {
307
36654
    async fn try_next(&mut self) -> Result<Option<DeviceRoute>, Box<dyn StdError>> {
308
36654
        self.offset += 1;
309
36654
        Ok(None)
310
73308
    }
311

            
312
3099
    fn offset(&self) -> u64 {
313
3099
        self.offset
314
3099
    }
315
}
316

            
317
/// Transforms query conditions to the SQL builder.
318
120
fn build_where<'a>(builder: &'a mut SqlBuilder, cond: &QueryCond<'a>) -> &'a mut SqlBuilder {
319
120
    if let Some(value) = cond.route_id {
320
8
        builder.and_where_eq("route_id", quote(value));
321
112
    }
322
120
    if let Some(value) = cond.unit_id {
323
36
        builder.and_where_eq("unit_id", quote(value));
324
84
    }
325
120
    if let Some(value) = cond.application_id {
326
19
        builder.and_where_eq("application_id", quote(value));
327
101
    }
328
120
    if let Some(value) = cond.network_id {
329
33
        builder.and_where_eq("network_id", quote(value));
330
87
    }
331
120
    if let Some(value) = cond.device_id {
332
45
        builder.and_where_eq("device_id", quote(value));
333
75
    }
334
120
    if let Some(value) = cond.network_addrs {
335
16488
        let values: Vec<String> = value.iter().map(|&x| quote(x)).collect();
336
22
        builder.and_where_in("network_addr", &values);
337
98
    }
338
120
    builder
339
120
}
340

            
341
/// Transforms query conditions to the SQL builder.
342
432
fn build_list_where<'a>(
343
432
    builder: &'a mut SqlBuilder,
344
432
    cond: &ListQueryCond<'a>,
345
432
) -> &'a mut SqlBuilder {
346
432
    if let Some(value) = cond.route_id {
347
34
        builder.and_where_eq("route_id", quote(value));
348
398
    }
349
432
    if let Some(value) = cond.unit_id {
350
130
        builder.and_where_eq("unit_id", quote(value));
351
302
    }
352
432
    if let Some(value) = cond.unit_code {
353
6
        builder.and_where_eq("unit_code", quote(value));
354
426
    }
355
432
    if let Some(value) = cond.application_id {
356
144
        builder.and_where_eq("application_id", quote(value));
357
288
    }
358
432
    if let Some(value) = cond.application_code {
359
6
        builder.and_where_eq("application_code", quote(value));
360
426
    }
361
432
    if let Some(value) = cond.network_id {
362
108
        builder.and_where_eq("network_id", quote(value));
363
324
    }
364
432
    if let Some(value) = cond.network_code {
365
6
        builder.and_where_eq("network_code", quote(value));
366
426
    }
367
432
    if let Some(value) = cond.network_addr {
368
6
        builder.and_where_eq("network_addr", quote(value));
369
426
    }
370
432
    if let Some(value) = cond.network_addrs {
371
12
        let values: Vec<String> = value.iter().map(|&x| quote(x)).collect();
372
4
        builder.and_where_in("network_addr", &values);
373
428
    }
374
432
    if let Some(value) = cond.device_id {
375
105
        builder.and_where_eq("device_id", quote(value));
376
327
    }
377
432
    builder
378
432
}
379

            
380
/// Transforms model options to the SQL builder.
381
322
fn build_limit_offset<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
382
322
    if let Some(value) = opts.limit {
383
183
        if value > 0 {
384
182
            builder.limit(value);
385
182
        }
386
139
    }
387
322
    if let Some(value) = opts.offset {
388
322
        match opts.limit {
389
139
            None => builder.limit(-1).offset(value),
390
1
            Some(0) => builder.limit(-1).offset(value),
391
182
            _ => builder.offset(value),
392
        };
393
    }
394
322
    builder
395
322
}
396

            
397
/// Transforms model options to the SQL builder.
398
322
fn build_sort<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
399
322
    if let Some(sort_cond) = opts.sort.as_ref() {
400
510
        for cond in sort_cond.iter() {
401
510
            let key = match cond.key {
402
152
                SortKey::CreatedAt => "created_at",
403
6
                SortKey::ModifiedAt => "modified_at",
404
6
                SortKey::ApplicationCode => "application_code",
405
144
                SortKey::NetworkCode => "network_code",
406
202
                SortKey::NetworkAddr => "network_addr",
407
            };
408
510
            builder.order_by(key, !cond.asc);
409
        }
410
103
    }
411
322
    builder
412
322
}
413

            
414
/// Transforms query conditions and the model object to the SQL builder.
415
21
fn build_update_where<'a>(
416
21
    builder: &'a mut SqlBuilder,
417
21
    cond: &UpdateQueryCond<'a>,
418
21
    updates: &Updates,
419
21
) -> Option<&'a mut SqlBuilder> {
420
21
    let mut count = 0;
421
21
    if let Some(value) = updates.modified_at.as_ref() {
422
20
        builder.set("modified_at", value.timestamp_millis());
423
20
        count += 1;
424
20
    }
425
21
    if let Some(value) = updates.profile.as_ref() {
426
19
        builder.set("profile", quote(value));
427
19
        count += 1;
428
19
    }
429
21
    if count == 0 {
430
1
        return None;
431
20
    }
432
20

            
433
20
    builder.and_where_eq("device_id", quote(cond.device_id));
434
20
    Some(builder)
435
21
}