1
use std::{error::Error as StdError, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeZone, Utc};
5
use futures::TryStreamExt;
6
use sql_builder::{quote, SqlBuilder};
7
use sqlx::SqlitePool;
8

            
9
use super::{
10
    super::network::{
11
        Cursor, ListOptions, ListQueryCond, Network, NetworkModel, QueryCond, SortKey,
12
        UpdateQueryCond, Updates,
13
    },
14
    build_where_like,
15
};
16

            
17
/// Model instance.
18
pub struct Model {
19
    /// The associated database connection.
20
    conn: Arc<SqlitePool>,
21
}
22

            
23
/// Cursor instance.
///
/// The SQLite cursor only remembers how many rows have been handed out so far. The list
/// options are passed again on every `list` call and are combined with this offset there.
#[derive(Default)]
pub struct DbCursor {
    /// Number of rows already returned through this cursor.
    offset: u64,
}
29

            
30
/// SQLite schema.
31
#[derive(sqlx::FromRow)]
32
struct Schema {
33
    network_id: String,
34
    code: String,
35
    /// use empty string as NULL because duplicate `(unit_id=NULL,code="code")` is allowed.
36
    unit_id: String,
37
    /// use empty string as NULL.
38
    unit_code: String,
39
    /// i64 as time tick from Epoch in milliseconds.
40
    created_at: i64,
41
    /// i64 as time tick from Epoch in milliseconds.
42
    modified_at: i64,
43
    host_uri: String,
44
    name: String,
45
    info: String,
46
}
47

            
48
/// Use "COUNT(*)" instead of "COUNT(fields...)" to simplify the implementation.
49
#[derive(sqlx::FromRow)]
50
struct CountSchema {
51
    #[sqlx(rename = "COUNT(*)")]
52
    count: i64,
53
}
54

            
55
/// Name of the SQLite table that stores networks.
const TABLE_NAME: &str = "network";
/// Columns of the `network` table, in the order used for `SELECT` and `INSERT` statements.
const FIELDS: &[&str] = &[
    "network_id",
    "code",
    "unit_id",
    "unit_code",
    "created_at",
    "modified_at",
    "host_uri",
    "name",
    "info",
];
/// Creates the `network` table when it does not exist yet.
///
/// `unit_id` and `unit_code` are `NOT NULL`: "no unit" is stored as an empty string so that
/// `UNIQUE (unit_id,code)` also rejects duplicated codes among networks without a unit.
const TABLE_INIT_SQL: &str = "\
    CREATE TABLE IF NOT EXISTS network (\
    network_id TEXT NOT NULL UNIQUE,\
    code TEXT NOT NULL,\
    unit_id TEXT NOT NULL,\
    unit_code TEXT NOT NULL,\
    created_at INTEGER NOT NULL,\
    modified_at INTEGER NOT NULL,\
    host_uri TEXT NOT NULL,\
    name TEXT NOT NULL,\
    info TEXT,\
    UNIQUE (unit_id,code),\
    PRIMARY KEY (network_id))";
80

            
81
impl Model {
82
    /// To create the model instance with a database connection.
83
28
    pub async fn new(conn: Arc<SqlitePool>) -> Result<Self, Box<dyn StdError>> {
84
28
        let model = Model { conn };
85
28
        model.init().await?;
86
28
        Ok(model)
87
28
    }
88
}
89

            
90
#[async_trait]
91
impl NetworkModel for Model {
92
48
    async fn init(&self) -> Result<(), Box<dyn StdError>> {
93
48
        let _ = sqlx::query(TABLE_INIT_SQL)
94
48
            .execute(self.conn.as_ref())
95
48
            .await?;
96
48
        Ok(())
97
96
    }
98

            
99
46
    async fn count(&self, cond: &ListQueryCond) -> Result<u64, Box<dyn StdError>> {
100
46
        let sql = build_list_where(SqlBuilder::select_from(TABLE_NAME).count("*"), &cond).sql()?;
101

            
102
46
        let result: Result<CountSchema, sqlx::Error> = sqlx::query_as(sql.as_str())
103
46
            .fetch_one(self.conn.as_ref())
104
46
            .await;
105

            
106
46
        let row = match result {
107
            Err(e) => return Err(Box::new(e)),
108
46
            Ok(row) => row,
109
46
        };
110
46
        Ok(row.count as u64)
111
92
    }
112

            
113
    async fn list(
114
        &self,
115
        opts: &ListOptions,
116
        cursor: Option<Box<dyn Cursor>>,
117
171
    ) -> Result<(Vec<Network>, Option<Box<dyn Cursor>>), Box<dyn StdError>> {
118
171
        let mut cursor = match cursor {
119
137
            None => Box::new(DbCursor::new()),
120
34
            Some(cursor) => cursor,
121
        };
122

            
123
171
        let mut opts = ListOptions { ..*opts };
124
171
        if let Some(offset) = opts.offset {
125
24
            opts.offset = Some(offset + cursor.offset());
126
147
        } else {
127
147
            opts.offset = Some(cursor.offset());
128
147
        }
129
171
        let opts_limit = opts.limit;
130
171
        if let Some(limit) = opts_limit {
131
90
            if limit > 0 {
132
89
                if cursor.offset() >= limit {
133
9
                    return Ok((vec![], None));
134
80
                }
135
80
                opts.limit = Some(limit - cursor.offset());
136
1
            }
137
81
        }
138
162
        let mut builder = SqlBuilder::select_from(TABLE_NAME);
139
162
        build_limit_offset(&mut builder, &opts);
140
162
        build_sort(&mut builder, &opts);
141
162
        let sql = build_list_where(&mut builder, opts.cond).sql()?;
142

            
143
162
        let mut rows = sqlx::query_as::<_, Schema>(sql.as_str()).fetch(self.conn.as_ref());
144
162

            
145
162
        let mut count: u64 = 0;
146
162
        let mut list = vec![];
147
3666
        while let Some(row) = rows.try_next().await? {
148
3548
            let _ = cursor.as_mut().try_next().await?;
149
3548
            list.push(Network {
150
3548
                network_id: row.network_id,
151
3548
                code: row.code,
152
3548
                unit_id: match row.unit_id.len() {
153
1038
                    0 => None,
154
2510
                    _ => Some(row.unit_id),
155
                },
156
3548
                unit_code: match row.unit_code.len() {
157
1038
                    0 => None,
158
2510
                    _ => Some(row.unit_code),
159
                },
160
3548
                created_at: Utc.timestamp_nanos(row.created_at * 1000000),
161
3548
                modified_at: Utc.timestamp_nanos(row.modified_at * 1000000),
162
3548
                host_uri: row.host_uri,
163
3548
                name: row.name,
164
3548
                info: serde_json::from_str(row.info.as_str())?,
165
            });
166
3548
            if let Some(limit) = opts_limit {
167
1453
                if limit > 0 && cursor.offset() >= limit {
168
19
                    if let Some(cursor_max) = opts.cursor_max {
169
18
                        if (count + 1) >= cursor_max {
170
9
                            return Ok((list, Some(cursor)));
171
9
                        }
172
1
                    }
173
10
                    return Ok((list, None));
174
1434
                }
175
2095
            }
176
3529
            if let Some(cursor_max) = opts.cursor_max {
177
3445
                count += 1;
178
3445
                if count >= cursor_max {
179
25
                    return Ok((list, Some(cursor)));
180
3420
                }
181
84
            }
182
        }
183
118
        Ok((list, None))
184
342
    }
185

            
186
569
    async fn get(&self, cond: &QueryCond) -> Result<Option<Network>, Box<dyn StdError>> {
187
569
        let sql = build_where(SqlBuilder::select_from(TABLE_NAME).fields(FIELDS), &cond).sql()?;
188

            
189
569
        let result: Result<Schema, sqlx::Error> = sqlx::query_as(sql.as_str())
190
569
            .fetch_one(self.conn.as_ref())
191
569
            .await;
192

            
193
569
        let row = match result {
194
88
            Err(e) => match e {
195
88
                sqlx::Error::RowNotFound => return Ok(None),
196
                _ => return Err(Box::new(e)),
197
            },
198
481
            Ok(row) => row,
199
481
        };
200
481

            
201
481
        Ok(Some(Network {
202
481
            network_id: row.network_id,
203
481
            code: row.code,
204
481
            unit_id: match row.unit_id.len() {
205
176
                0 => None,
206
305
                _ => Some(row.unit_id),
207
            },
208
481
            unit_code: match row.unit_code.len() {
209
176
                0 => None,
210
305
                _ => Some(row.unit_code),
211
            },
212
481
            created_at: Utc.timestamp_nanos(row.created_at * 1000000),
213
481
            modified_at: Utc.timestamp_nanos(row.modified_at * 1000000),
214
481
            host_uri: row.host_uri,
215
481
            name: row.name,
216
481
            info: serde_json::from_str(row.info.as_str())?,
217
        }))
218
1138
    }
219

            
220
2560
    async fn add(&self, network: &Network) -> Result<(), Box<dyn StdError>> {
221
2560
        let unit_id = match network.unit_id.as_deref() {
222
562
            None => quote(""),
223
1998
            Some(value) => quote(value),
224
        };
225
2560
        let unit_code = match network.unit_code.as_deref() {
226
562
            None => quote(""),
227
1998
            Some(value) => quote(value),
228
        };
229
2560
        let info = match serde_json::to_string(&network.info) {
230
            Err(_) => quote("{}"),
231
2560
            Ok(value) => quote(value.as_str()),
232
        };
233
2560
        let values = vec![
234
2560
            quote(network.network_id.as_str()),
235
2560
            quote(network.code.as_str()),
236
2560
            unit_id,
237
2560
            unit_code,
238
2560
            network.created_at.timestamp_millis().to_string(),
239
2560
            network.modified_at.timestamp_millis().to_string(),
240
2560
            quote(network.host_uri.as_str()),
241
2560
            quote(network.name.as_str()),
242
2560
            info,
243
2560
        ];
244
2560
        let sql = SqlBuilder::insert_into(TABLE_NAME)
245
2560
            .fields(FIELDS)
246
2560
            .values(&values)
247
2560
            .sql()?;
248
2560
        let _ = sqlx::query(sql.as_str())
249
2560
            .execute(self.conn.as_ref())
250
2560
            .await?;
251
2556
        Ok(())
252
5120
    }
253

            
254
32
    async fn del(&self, cond: &QueryCond) -> Result<(), Box<dyn StdError>> {
255
32
        let sql = build_where(&mut SqlBuilder::delete_from(TABLE_NAME), cond).sql()?;
256
32
        let _ = sqlx::query(sql.as_str())
257
32
            .execute(self.conn.as_ref())
258
32
            .await?;
259
32
        Ok(())
260
64
    }
261

            
262
    async fn update(
263
        &self,
264
        cond: &UpdateQueryCond,
265
        updates: &Updates,
266
17
    ) -> Result<(), Box<dyn StdError>> {
267
17
        let sql = match build_update_where(&mut SqlBuilder::update_table(TABLE_NAME), cond, updates)
268
        {
269
1
            None => return Ok(()),
270
16
            Some(builder) => builder.sql()?,
271
        };
272
16
        let _ = sqlx::query(sql.as_str())
273
16
            .execute(self.conn.as_ref())
274
16
            .await?;
275
16
        Ok(())
276
34
    }
277
}
278

            
279
impl DbCursor {
280
    /// To create the cursor instance.
281
137
    pub fn new() -> Self {
282
137
        DbCursor { offset: 0 }
283
137
    }
284
}
285

            
286
#[async_trait]
287
impl Cursor for DbCursor {
288
3548
    async fn try_next(&mut self) -> Result<Option<Network>, Box<dyn StdError>> {
289
3548
        self.offset += 1;
290
3548
        Ok(None)
291
7096
    }
292

            
293
1793
    fn offset(&self) -> u64 {
294
1793
        self.offset
295
1793
    }
296
}
297

            
298
/// Transforms query conditions to the SQL builder.
299
601
fn build_where<'a>(builder: &'a mut SqlBuilder, cond: &QueryCond<'a>) -> &'a mut SqlBuilder {
300
601
    if let Some(value) = cond.unit_id {
301
53
        match value {
302
14
            None => {
303
14
                builder.and_where_eq("unit_id", quote(""));
304
14
            }
305
39
            Some(value) => {
306
39
                builder.and_where_eq("unit_id", quote(value));
307
39
            }
308
        }
309
548
    }
310
601
    if let Some(value) = cond.network_id {
311
555
        builder.and_where_eq("network_id", quote(value));
312
555
    }
313
601
    if let Some(value) = cond.code {
314
30
        builder.and_where_eq("code", quote(value));
315
571
    }
316
601
    builder
317
601
}
318

            
319
/// Transforms query conditions to the SQL builder.
320
208
fn build_list_where<'a>(
321
208
    builder: &'a mut SqlBuilder,
322
208
    cond: &ListQueryCond<'a>,
323
208
) -> &'a mut SqlBuilder {
324
208
    if let Some(value) = cond.unit_id {
325
94
        match value {
326
18
            None => {
327
18
                builder.and_where_eq("unit_id", quote(""));
328
18
            }
329
76
            Some(value) => {
330
76
                builder.and_where_eq("unit_id", quote(value));
331
76
            }
332
        }
333
114
    }
334
208
    if let Some(value) = cond.network_id {
335
6
        builder.and_where_eq("network_id", quote(value));
336
202
    }
337
208
    if let Some(value) = cond.code {
338
24
        builder.and_where_eq("code", quote(value));
339
184
    }
340
208
    if let Some(value) = cond.code_contains {
341
59
        build_where_like(builder, "code", value.to_lowercase().as_str());
342
149
    }
343
208
    if let Some(value) = cond.name_contains {
344
8
        build_where_like(builder, "name", value.to_lowercase().as_str());
345
200
    }
346
208
    builder
347
208
}
348

            
349
/// Transforms model options to the SQL builder.
350
162
fn build_limit_offset<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
351
162
    if let Some(value) = opts.limit {
352
81
        if value > 0 {
353
80
            builder.limit(value);
354
80
        }
355
81
    }
356
162
    if let Some(value) = opts.offset {
357
162
        match opts.limit {
358
81
            None => builder.limit(-1).offset(value),
359
1
            Some(0) => builder.limit(-1).offset(value),
360
80
            _ => builder.offset(value),
361
        };
362
    }
363
162
    builder
364
162
}
365

            
366
/// Transforms model options to the SQL builder.
367
162
fn build_sort<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
368
162
    if let Some(sort_cond) = opts.sort.as_ref() {
369
122
        for cond in sort_cond.iter() {
370
122
            let key = match cond.key {
371
8
                SortKey::CreatedAt => "created_at",
372
6
                SortKey::ModifiedAt => "modified_at",
373
100
                SortKey::Code => "code",
374
8
                SortKey::Name => "name",
375
            };
376
122
            builder.order_by(key, !cond.asc);
377
        }
378
41
    }
379
162
    builder
380
162
}
381

            
382
/// Transforms query conditions and the model object to the SQL builder.
383
17
fn build_update_where<'a>(
384
17
    builder: &'a mut SqlBuilder,
385
17
    cond: &UpdateQueryCond<'a>,
386
17
    updates: &Updates,
387
17
) -> Option<&'a mut SqlBuilder> {
388
17
    let mut count = 0;
389
17
    if let Some(value) = updates.modified_at.as_ref() {
390
16
        builder.set("modified_at", value.timestamp_millis());
391
16
        count += 1;
392
16
    }
393
17
    if let Some(value) = updates.host_uri.as_ref() {
394
8
        builder.set("host_uri", quote(value));
395
8
        count += 1;
396
9
    }
397
17
    if let Some(value) = updates.name.as_ref() {
398
14
        builder.set("name", quote(value));
399
14
        count += 1;
400
14
    }
401
17
    if let Some(value) = updates.info {
402
14
        match serde_json::to_string(value) {
403
            Err(_) => {
404
                builder.set("info", quote("{}"));
405
            }
406
14
            Ok(value) => {
407
14
                builder.set("info", quote(value));
408
14
            }
409
        }
410
14
        count += 1;
411
3
    }
412
17
    if count == 0 {
413
1
        return None;
414
16
    }
415
16

            
416
16
    builder.and_where_eq("network_id", quote(cond.network_id));
417
16
    Some(builder)
418
17
}