1
use std::{error::Error as StdError, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeZone, Utc};
5
use futures::TryStreamExt;
6
use sql_builder::{quote, SqlBuilder};
7
use sqlx::SqlitePool;
8

            
9
use super::{
10
    super::network::{
11
        Cursor, ListOptions, ListQueryCond, Network, NetworkModel, QueryCond, SortKey,
12
        UpdateQueryCond, Updates,
13
    },
14
    build_where_like,
15
};
16

            
17
/// Model instance.
18
pub struct Model {
19
    /// The associated database connection.
20
    conn: Arc<SqlitePool>,
21
}
22

            
23
/// Cursor instance for paging through `list` results.
///
/// The SQLite implementation only remembers how many rows have already been consumed; the
/// list options are supplied again by the caller on every `list` call and combined with this
/// offset to compute the next page.
pub struct DbCursor {
    /// Number of rows already returned through this cursor.
    offset: u64,
}
29

            
30
/// SQLite schema.
31
#[derive(sqlx::FromRow)]
32
struct Schema {
33
    network_id: String,
34
    code: String,
35
    /// use empty string as NULL because duplicate `(unit_id=NULL,code="code")` is allowed.
36
    unit_id: String,
37
    /// use empty string as NULL.
38
    unit_code: String,
39
    /// i64 as time tick from Epoch in milliseconds.
40
    created_at: i64,
41
    /// i64 as time tick from Epoch in milliseconds.
42
    modified_at: i64,
43
    host_uri: String,
44
    name: String,
45
    info: String,
46
}
47

            
48
/// Use "COUNT(*)" instead of "COUNT(fields...)" to simplify the implementation.
49
#[derive(sqlx::FromRow)]
50
struct CountSchema {
51
    #[sqlx(rename = "COUNT(*)")]
52
    count: i64,
53
}
54

            
55
/// Name of the table backing the network model.
const TABLE_NAME: &str = "network";

/// Column list, in the order used for `SELECT` and `INSERT` statements.
const FIELDS: &[&str] = &[
    "network_id",
    "code",
    "unit_id",
    "unit_code",
    "created_at",
    "modified_at",
    "host_uri",
    "name",
    "info",
];

/// DDL that creates the table when it does not exist yet.
///
/// Each line ends with `\`, so the newline and the following indentation are dropped and the
/// statement is a single line of SQL.
const TABLE_INIT_SQL: &str = "\
    CREATE TABLE IF NOT EXISTS network (\
    network_id TEXT NOT NULL UNIQUE,\
    code TEXT NOT NULL,\
    unit_id TEXT NOT NULL,\
    unit_code TEXT NOT NULL,\
    created_at INTEGER NOT NULL,\
    modified_at INTEGER NOT NULL,\
    host_uri TEXT NOT NULL,\
    name TEXT NOT NULL,\
    info TEXT,\
    UNIQUE (unit_id,code),\
    PRIMARY KEY (network_id))";
80

            
81
impl Model {
82
    /// To create the model instance with a database connection.
83
56
    pub async fn new(conn: Arc<SqlitePool>) -> Result<Self, Box<dyn StdError>> {
84
56
        let model = Model { conn };
85
56
        model.init().await?;
86
56
        Ok(model)
87
56
    }
88
}
89

            
90
#[async_trait]
91
impl NetworkModel for Model {
92
96
    async fn init(&self) -> Result<(), Box<dyn StdError>> {
93
96
        let _ = sqlx::query(TABLE_INIT_SQL)
94
96
            .execute(self.conn.as_ref())
95
96
            .await?;
96
96
        Ok(())
97
192
    }
98

            
99
92
    async fn count(&self, cond: &ListQueryCond) -> Result<u64, Box<dyn StdError>> {
100
92
        let sql = build_list_where(SqlBuilder::select_from(TABLE_NAME).count("*"), &cond).sql()?;
101

            
102
92
        let result: Result<CountSchema, sqlx::Error> = sqlx::query_as(sql.as_str())
103
92
            .fetch_one(self.conn.as_ref())
104
92
            .await;
105

            
106
92
        let row = match result {
107
            Err(e) => return Err(Box::new(e)),
108
92
            Ok(row) => row,
109
92
        };
110
92
        Ok(row.count as u64)
111
184
    }
112

            
113
    async fn list(
114
        &self,
115
        opts: &ListOptions,
116
        cursor: Option<Box<dyn Cursor>>,
117
342
    ) -> Result<(Vec<Network>, Option<Box<dyn Cursor>>), Box<dyn StdError>> {
118
342
        let mut cursor = match cursor {
119
274
            None => Box::new(DbCursor::new()),
120
68
            Some(cursor) => cursor,
121
        };
122

            
123
342
        let mut opts = ListOptions { ..*opts };
124
342
        if let Some(offset) = opts.offset {
125
48
            opts.offset = Some(offset + cursor.offset());
126
294
        } else {
127
294
            opts.offset = Some(cursor.offset());
128
294
        }
129
342
        let opts_limit = opts.limit;
130
342
        if let Some(limit) = opts_limit {
131
180
            if limit > 0 {
132
178
                if cursor.offset() >= limit {
133
18
                    return Ok((vec![], None));
134
160
                }
135
160
                opts.limit = Some(limit - cursor.offset());
136
2
            }
137
162
        }
138
324
        let mut builder = SqlBuilder::select_from(TABLE_NAME);
139
324
        build_limit_offset(&mut builder, &opts);
140
324
        build_sort(&mut builder, &opts);
141
324
        let sql = build_list_where(&mut builder, opts.cond).sql()?;
142

            
143
324
        let mut rows = sqlx::query_as::<_, Schema>(sql.as_str()).fetch(self.conn.as_ref());
144
324

            
145
324
        let mut count: u64 = 0;
146
324
        let mut list = vec![];
147
7332
        while let Some(row) = rows.try_next().await? {
148
7096
            let _ = cursor.as_mut().try_next().await?;
149
7096
            list.push(Network {
150
7096
                network_id: row.network_id,
151
7096
                code: row.code,
152
7096
                unit_id: match row.unit_id.len() {
153
2076
                    0 => None,
154
5020
                    _ => Some(row.unit_id),
155
                },
156
7096
                unit_code: match row.unit_code.len() {
157
2076
                    0 => None,
158
5020
                    _ => Some(row.unit_code),
159
                },
160
7096
                created_at: Utc.timestamp_nanos(row.created_at * 1000000),
161
7096
                modified_at: Utc.timestamp_nanos(row.modified_at * 1000000),
162
7096
                host_uri: row.host_uri,
163
7096
                name: row.name,
164
7096
                info: serde_json::from_str(row.info.as_str())?,
165
            });
166
7096
            if let Some(limit) = opts_limit {
167
2906
                if limit > 0 && cursor.offset() >= limit {
168
38
                    if let Some(cursor_max) = opts.cursor_max {
169
36
                        if (count + 1) >= cursor_max {
170
18
                            return Ok((list, Some(cursor)));
171
18
                        }
172
2
                    }
173
20
                    return Ok((list, None));
174
2868
                }
175
4190
            }
176
7058
            if let Some(cursor_max) = opts.cursor_max {
177
6890
                count += 1;
178
6890
                if count >= cursor_max {
179
50
                    return Ok((list, Some(cursor)));
180
6840
                }
181
168
            }
182
        }
183
236
        Ok((list, None))
184
684
    }
185

            
186
1138
    async fn get(&self, cond: &QueryCond) -> Result<Option<Network>, Box<dyn StdError>> {
187
1138
        let sql = build_where(SqlBuilder::select_from(TABLE_NAME).fields(FIELDS), &cond).sql()?;
188

            
189
1138
        let result: Result<Schema, sqlx::Error> = sqlx::query_as(sql.as_str())
190
1138
            .fetch_one(self.conn.as_ref())
191
1138
            .await;
192

            
193
1138
        let row = match result {
194
176
            Err(e) => match e {
195
176
                sqlx::Error::RowNotFound => return Ok(None),
196
                _ => return Err(Box::new(e)),
197
            },
198
962
            Ok(row) => row,
199
962
        };
200
962

            
201
962
        Ok(Some(Network {
202
962
            network_id: row.network_id,
203
962
            code: row.code,
204
962
            unit_id: match row.unit_id.len() {
205
352
                0 => None,
206
610
                _ => Some(row.unit_id),
207
            },
208
962
            unit_code: match row.unit_code.len() {
209
352
                0 => None,
210
610
                _ => Some(row.unit_code),
211
            },
212
962
            created_at: Utc.timestamp_nanos(row.created_at * 1000000),
213
962
            modified_at: Utc.timestamp_nanos(row.modified_at * 1000000),
214
962
            host_uri: row.host_uri,
215
962
            name: row.name,
216
962
            info: serde_json::from_str(row.info.as_str())?,
217
        }))
218
2276
    }
219

            
220
5120
    async fn add(&self, network: &Network) -> Result<(), Box<dyn StdError>> {
221
5120
        let unit_id = match network.unit_id.as_deref() {
222
1124
            None => quote(""),
223
3996
            Some(value) => quote(value),
224
        };
225
5120
        let unit_code = match network.unit_code.as_deref() {
226
1124
            None => quote(""),
227
3996
            Some(value) => quote(value),
228
        };
229
5120
        let info = match serde_json::to_string(&network.info) {
230
            Err(_) => quote("{}"),
231
5120
            Ok(value) => quote(value.as_str()),
232
        };
233
5120
        let values = vec![
234
5120
            quote(network.network_id.as_str()),
235
5120
            quote(network.code.as_str()),
236
5120
            unit_id,
237
5120
            unit_code,
238
5120
            network.created_at.timestamp_millis().to_string(),
239
5120
            network.modified_at.timestamp_millis().to_string(),
240
5120
            quote(network.host_uri.as_str()),
241
5120
            quote(network.name.as_str()),
242
5120
            info,
243
5120
        ];
244
5120
        let sql = SqlBuilder::insert_into(TABLE_NAME)
245
5120
            .fields(FIELDS)
246
5120
            .values(&values)
247
5120
            .sql()?;
248
5120
        let _ = sqlx::query(sql.as_str())
249
5120
            .execute(self.conn.as_ref())
250
5120
            .await?;
251
5112
        Ok(())
252
10240
    }
253

            
254
64
    async fn del(&self, cond: &QueryCond) -> Result<(), Box<dyn StdError>> {
255
64
        let sql = build_where(&mut SqlBuilder::delete_from(TABLE_NAME), cond).sql()?;
256
64
        let _ = sqlx::query(sql.as_str())
257
64
            .execute(self.conn.as_ref())
258
64
            .await?;
259
64
        Ok(())
260
128
    }
261

            
262
    async fn update(
263
        &self,
264
        cond: &UpdateQueryCond,
265
        updates: &Updates,
266
34
    ) -> Result<(), Box<dyn StdError>> {
267
34
        let sql = match build_update_where(&mut SqlBuilder::update_table(TABLE_NAME), cond, updates)
268
        {
269
2
            None => return Ok(()),
270
32
            Some(builder) => builder.sql()?,
271
        };
272
32
        let _ = sqlx::query(sql.as_str())
273
32
            .execute(self.conn.as_ref())
274
32
            .await?;
275
32
        Ok(())
276
68
    }
277
}
278

            
279
impl DbCursor {
280
    /// To create the cursor instance.
281
274
    pub fn new() -> Self {
282
274
        DbCursor { offset: 0 }
283
274
    }
284
}
285

            
286
#[async_trait]
287
impl Cursor for DbCursor {
288
7096
    async fn try_next(&mut self) -> Result<Option<Network>, Box<dyn StdError>> {
289
7096
        self.offset += 1;
290
7096
        Ok(None)
291
14192
    }
292

            
293
3586
    fn offset(&self) -> u64 {
294
3586
        self.offset
295
3586
    }
296
}
297

            
298
/// Transforms query conditions to the SQL builder.
299
1202
fn build_where<'a>(builder: &'a mut SqlBuilder, cond: &QueryCond<'a>) -> &'a mut SqlBuilder {
300
1202
    if let Some(value) = cond.unit_id {
301
106
        match value {
302
28
            None => {
303
28
                builder.and_where_eq("unit_id", quote(""));
304
28
            }
305
78
            Some(value) => {
306
78
                builder.and_where_eq("unit_id", quote(value));
307
78
            }
308
        }
309
1096
    }
310
1202
    if let Some(value) = cond.network_id {
311
1110
        builder.and_where_eq("network_id", quote(value));
312
1110
    }
313
1202
    if let Some(value) = cond.code {
314
60
        builder.and_where_eq("code", quote(value));
315
1142
    }
316
1202
    builder
317
1202
}
318

            
319
/// Transforms query conditions to the SQL builder.
320
416
fn build_list_where<'a>(
321
416
    builder: &'a mut SqlBuilder,
322
416
    cond: &ListQueryCond<'a>,
323
416
) -> &'a mut SqlBuilder {
324
416
    if let Some(value) = cond.unit_id {
325
188
        match value {
326
36
            None => {
327
36
                builder.and_where_eq("unit_id", quote(""));
328
36
            }
329
152
            Some(value) => {
330
152
                builder.and_where_eq("unit_id", quote(value));
331
152
            }
332
        }
333
228
    }
334
416
    if let Some(value) = cond.network_id {
335
12
        builder.and_where_eq("network_id", quote(value));
336
404
    }
337
416
    if let Some(value) = cond.code {
338
48
        builder.and_where_eq("code", quote(value));
339
368
    }
340
416
    if let Some(value) = cond.code_contains {
341
118
        build_where_like(builder, "code", value.to_lowercase().as_str());
342
298
    }
343
416
    if let Some(value) = cond.name_contains {
344
16
        build_where_like(builder, "name", value.to_lowercase().as_str());
345
400
    }
346
416
    builder
347
416
}
348

            
349
/// Transforms model options to the SQL builder.
350
324
fn build_limit_offset<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
351
324
    if let Some(value) = opts.limit {
352
162
        if value > 0 {
353
160
            builder.limit(value);
354
160
        }
355
162
    }
356
324
    if let Some(value) = opts.offset {
357
324
        match opts.limit {
358
162
            None => builder.limit(-1).offset(value),
359
2
            Some(0) => builder.limit(-1).offset(value),
360
160
            _ => builder.offset(value),
361
        };
362
    }
363
324
    builder
364
324
}
365

            
366
/// Transforms model options to the SQL builder.
367
324
fn build_sort<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
368
324
    if let Some(sort_cond) = opts.sort.as_ref() {
369
244
        for cond in sort_cond.iter() {
370
244
            let key = match cond.key {
371
16
                SortKey::CreatedAt => "created_at",
372
12
                SortKey::ModifiedAt => "modified_at",
373
200
                SortKey::Code => "code",
374
16
                SortKey::Name => "name",
375
            };
376
244
            builder.order_by(key, !cond.asc);
377
        }
378
82
    }
379
324
    builder
380
324
}
381

            
382
/// Transforms query conditions and the model object to the SQL builder.
383
34
fn build_update_where<'a>(
384
34
    builder: &'a mut SqlBuilder,
385
34
    cond: &UpdateQueryCond<'a>,
386
34
    updates: &Updates,
387
34
) -> Option<&'a mut SqlBuilder> {
388
34
    let mut count = 0;
389
34
    if let Some(value) = updates.modified_at.as_ref() {
390
32
        builder.set("modified_at", value.timestamp_millis());
391
32
        count += 1;
392
32
    }
393
34
    if let Some(value) = updates.host_uri.as_ref() {
394
16
        builder.set("host_uri", quote(value));
395
16
        count += 1;
396
18
    }
397
34
    if let Some(value) = updates.name.as_ref() {
398
28
        builder.set("name", quote(value));
399
28
        count += 1;
400
28
    }
401
34
    if let Some(value) = updates.info {
402
28
        match serde_json::to_string(value) {
403
            Err(_) => {
404
                builder.set("info", quote("{}"));
405
            }
406
28
            Ok(value) => {
407
28
                builder.set("info", quote(value));
408
28
            }
409
        }
410
28
        count += 1;
411
6
    }
412
34
    if count == 0 {
413
2
        return None;
414
32
    }
415
32

            
416
32
    builder.and_where_eq("network_id", quote(cond.network_id));
417
32
    Some(builder)
418
34
}