1
use std::{error::Error as StdError, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeZone, Utc};
5
use futures::TryStreamExt;
6
use sql_builder::{SqlBuilder, quote};
7
use sqlx::SqlitePool;
8

            
9
use super::{
10
    super::network::{
11
        Cursor, ListOptions, ListQueryCond, Network, NetworkModel, QueryCond, SortKey,
12
        UpdateQueryCond, Updates,
13
    },
14
    build_where_like,
15
};
16

            
17
/// Model instance.
18
pub struct Model {
19
    /// The associated database connection.
20
    conn: Arc<SqlitePool>,
21
}
22

            
23
/// Paging cursor for the SQLite network model.
///
/// It only records how many rows have been handed out so far; the next page is
/// fetched by re-running the list query with this count added to the offset.
pub struct DbCursor {
    /// Number of rows already consumed through this cursor.
    offset: u64,
}
29

            
30
/// SQLite schema.
31
#[derive(sqlx::FromRow)]
32
struct Schema {
33
    network_id: String,
34
    code: String,
35
    /// use empty string as NULL because duplicate `(unit_id=NULL,code="code")` is allowed.
36
    unit_id: String,
37
    /// use empty string as NULL.
38
    unit_code: String,
39
    /// i64 as time tick from Epoch in milliseconds.
40
    created_at: i64,
41
    /// i64 as time tick from Epoch in milliseconds.
42
    modified_at: i64,
43
    host_uri: String,
44
    name: String,
45
    info: String,
46
}
47

            
48
/// Use "COUNT(*)" instead of "COUNT(fields...)" to simplify the implementation.
49
#[derive(sqlx::FromRow)]
50
struct CountSchema {
51
    #[sqlx(rename = "COUNT(*)")]
52
    count: i64,
53
}
54

            
55
/// Name of the SQLite table holding networks.
const TABLE_NAME: &str = "network";

/// Column list, in the order used for `INSERT` and explicit `SELECT`s.
const FIELDS: &[&str] = &[
    "network_id",
    "code",
    "unit_id",
    "unit_code",
    "created_at",
    "modified_at",
    "host_uri",
    "name",
    "info",
];

/// DDL creating the `network` table when it does not exist yet.
const TABLE_INIT_SQL: &str = concat!(
    "CREATE TABLE IF NOT EXISTS network (",
    "network_id TEXT NOT NULL UNIQUE,",
    "code TEXT NOT NULL,",
    "unit_id TEXT NOT NULL,",
    "unit_code TEXT NOT NULL,",
    "created_at INTEGER NOT NULL,",
    "modified_at INTEGER NOT NULL,",
    "host_uri TEXT NOT NULL,",
    "name TEXT NOT NULL,",
    "info TEXT,",
    "UNIQUE (unit_id,code),",
    "PRIMARY KEY (network_id))"
);
80

            
81
impl Model {
82
    /// To create the model instance with a database connection.
83
56
    pub async fn new(conn: Arc<SqlitePool>) -> Result<Self, Box<dyn StdError>> {
84
56
        let model = Model { conn };
85
56
        model.init().await?;
86
56
        Ok(model)
87
56
    }
88
}
89

            
90
#[async_trait]
91
impl NetworkModel for Model {
92
96
    async fn init(&self) -> Result<(), Box<dyn StdError>> {
93
        let _ = sqlx::query(TABLE_INIT_SQL)
94
            .execute(self.conn.as_ref())
95
            .await?;
96
        Ok(())
97
96
    }
98

            
99
92
    async fn count(&self, cond: &ListQueryCond) -> Result<u64, Box<dyn StdError>> {
100
        let sql = build_list_where(SqlBuilder::select_from(TABLE_NAME).count("*"), &cond).sql()?;
101

            
102
        let result: Result<CountSchema, sqlx::Error> = sqlx::query_as(sql.as_str())
103
            .fetch_one(self.conn.as_ref())
104
            .await;
105

            
106
        let row = match result {
107
            Err(e) => return Err(Box::new(e)),
108
            Ok(row) => row,
109
        };
110
        Ok(row.count as u64)
111
92
    }
112

            
113
    async fn list(
114
        &self,
115
        opts: &ListOptions,
116
        cursor: Option<Box<dyn Cursor>>,
117
342
    ) -> Result<(Vec<Network>, Option<Box<dyn Cursor>>), Box<dyn StdError>> {
118
        let mut cursor = match cursor {
119
            None => Box::new(DbCursor::new()),
120
            Some(cursor) => cursor,
121
        };
122

            
123
        let mut opts = ListOptions { ..*opts };
124
        if let Some(offset) = opts.offset {
125
            opts.offset = Some(offset + cursor.offset());
126
        } else {
127
            opts.offset = Some(cursor.offset());
128
        }
129
        let opts_limit = opts.limit;
130
        if let Some(limit) = opts_limit {
131
            if limit > 0 {
132
                if cursor.offset() >= limit {
133
                    return Ok((vec![], None));
134
                }
135
                opts.limit = Some(limit - cursor.offset());
136
            }
137
        }
138
        let mut builder = SqlBuilder::select_from(TABLE_NAME);
139
        build_limit_offset(&mut builder, &opts);
140
        build_sort(&mut builder, &opts);
141
        let sql = build_list_where(&mut builder, opts.cond).sql()?;
142

            
143
        let mut rows = sqlx::query_as::<_, Schema>(sql.as_str()).fetch(self.conn.as_ref());
144

            
145
        let mut count: u64 = 0;
146
        let mut list = vec![];
147
        while let Some(row) = rows.try_next().await? {
148
            let _ = cursor.as_mut().try_next().await?;
149
            list.push(Network {
150
                network_id: row.network_id,
151
                code: row.code,
152
                unit_id: match row.unit_id.len() {
153
                    0 => None,
154
                    _ => Some(row.unit_id),
155
                },
156
                unit_code: match row.unit_code.len() {
157
                    0 => None,
158
                    _ => Some(row.unit_code),
159
                },
160
                created_at: Utc.timestamp_nanos(row.created_at * 1000000),
161
                modified_at: Utc.timestamp_nanos(row.modified_at * 1000000),
162
                host_uri: row.host_uri,
163
                name: row.name,
164
                info: serde_json::from_str(row.info.as_str())?,
165
            });
166
            if let Some(limit) = opts_limit {
167
                if limit > 0 && cursor.offset() >= limit {
168
                    if let Some(cursor_max) = opts.cursor_max {
169
                        if (count + 1) >= cursor_max {
170
                            return Ok((list, Some(cursor)));
171
                        }
172
                    }
173
                    return Ok((list, None));
174
                }
175
            }
176
            if let Some(cursor_max) = opts.cursor_max {
177
                count += 1;
178
                if count >= cursor_max {
179
                    return Ok((list, Some(cursor)));
180
                }
181
            }
182
        }
183
        Ok((list, None))
184
342
    }
185

            
186
1138
    async fn get(&self, cond: &QueryCond) -> Result<Option<Network>, Box<dyn StdError>> {
187
        let sql = build_where(SqlBuilder::select_from(TABLE_NAME).fields(FIELDS), &cond).sql()?;
188

            
189
        let result: Result<Schema, sqlx::Error> = sqlx::query_as(sql.as_str())
190
            .fetch_one(self.conn.as_ref())
191
            .await;
192

            
193
        let row = match result {
194
            Err(e) => match e {
195
                sqlx::Error::RowNotFound => return Ok(None),
196
                _ => return Err(Box::new(e)),
197
            },
198
            Ok(row) => row,
199
        };
200

            
201
        Ok(Some(Network {
202
            network_id: row.network_id,
203
            code: row.code,
204
            unit_id: match row.unit_id.len() {
205
                0 => None,
206
                _ => Some(row.unit_id),
207
            },
208
            unit_code: match row.unit_code.len() {
209
                0 => None,
210
                _ => Some(row.unit_code),
211
            },
212
            created_at: Utc.timestamp_nanos(row.created_at * 1000000),
213
            modified_at: Utc.timestamp_nanos(row.modified_at * 1000000),
214
            host_uri: row.host_uri,
215
            name: row.name,
216
            info: serde_json::from_str(row.info.as_str())?,
217
        }))
218
1138
    }
219

            
220
5120
    async fn add(&self, network: &Network) -> Result<(), Box<dyn StdError>> {
221
        let unit_id = match network.unit_id.as_deref() {
222
            None => quote(""),
223
            Some(value) => quote(value),
224
        };
225
        let unit_code = match network.unit_code.as_deref() {
226
            None => quote(""),
227
            Some(value) => quote(value),
228
        };
229
        let info = match serde_json::to_string(&network.info) {
230
            Err(_) => quote("{}"),
231
            Ok(value) => quote(value.as_str()),
232
        };
233
        let values = vec![
234
            quote(network.network_id.as_str()),
235
            quote(network.code.as_str()),
236
            unit_id,
237
            unit_code,
238
            network.created_at.timestamp_millis().to_string(),
239
            network.modified_at.timestamp_millis().to_string(),
240
            quote(network.host_uri.as_str()),
241
            quote(network.name.as_str()),
242
            info,
243
        ];
244
        let sql = SqlBuilder::insert_into(TABLE_NAME)
245
            .fields(FIELDS)
246
            .values(&values)
247
            .sql()?;
248
        let _ = sqlx::query(sql.as_str())
249
            .execute(self.conn.as_ref())
250
            .await?;
251
        Ok(())
252
5120
    }
253

            
254
64
    async fn del(&self, cond: &QueryCond) -> Result<(), Box<dyn StdError>> {
255
        let sql = build_where(&mut SqlBuilder::delete_from(TABLE_NAME), cond).sql()?;
256
        let _ = sqlx::query(sql.as_str())
257
            .execute(self.conn.as_ref())
258
            .await?;
259
        Ok(())
260
64
    }
261

            
262
    async fn update(
263
        &self,
264
        cond: &UpdateQueryCond,
265
        updates: &Updates,
266
34
    ) -> Result<(), Box<dyn StdError>> {
267
        let sql = match build_update_where(&mut SqlBuilder::update_table(TABLE_NAME), cond, updates)
268
        {
269
            None => return Ok(()),
270
            Some(builder) => builder.sql()?,
271
        };
272
        let _ = sqlx::query(sql.as_str())
273
            .execute(self.conn.as_ref())
274
            .await?;
275
        Ok(())
276
34
    }
277
}
278

            
279
impl DbCursor {
280
    /// To create the cursor instance.
281
274
    pub fn new() -> Self {
282
274
        DbCursor { offset: 0 }
283
274
    }
284
}
285

            
286
#[async_trait]
287
impl Cursor for DbCursor {
288
7096
    async fn try_next(&mut self) -> Result<Option<Network>, Box<dyn StdError>> {
289
        self.offset += 1;
290
        Ok(None)
291
7096
    }
292

            
293
3586
    fn offset(&self) -> u64 {
294
3586
        self.offset
295
3586
    }
296
}
297

            
298
/// Transforms query conditions to the SQL builder.
299
1202
fn build_where<'a>(builder: &'a mut SqlBuilder, cond: &QueryCond<'a>) -> &'a mut SqlBuilder {
300
1202
    if let Some(value) = cond.unit_id {
301
106
        match value {
302
28
            None => {
303
28
                builder.and_where_eq("unit_id", quote(""));
304
28
            }
305
78
            Some(value) => {
306
78
                builder.and_where_eq("unit_id", quote(value));
307
78
            }
308
        }
309
1096
    }
310
1202
    if let Some(value) = cond.network_id {
311
1110
        builder.and_where_eq("network_id", quote(value));
312
1110
    }
313
1202
    if let Some(value) = cond.code {
314
60
        builder.and_where_eq("code", quote(value));
315
1142
    }
316
1202
    builder
317
1202
}
318

            
319
/// Transforms query conditions to the SQL builder.
320
416
fn build_list_where<'a>(
321
416
    builder: &'a mut SqlBuilder,
322
416
    cond: &ListQueryCond<'a>,
323
416
) -> &'a mut SqlBuilder {
324
416
    if let Some(value) = cond.unit_id {
325
188
        match value {
326
36
            None => {
327
36
                builder.and_where_eq("unit_id", quote(""));
328
36
            }
329
152
            Some(value) => {
330
152
                builder.and_where_eq("unit_id", quote(value));
331
152
            }
332
        }
333
228
    }
334
416
    if let Some(value) = cond.network_id {
335
12
        builder.and_where_eq("network_id", quote(value));
336
404
    }
337
416
    if let Some(value) = cond.code {
338
48
        builder.and_where_eq("code", quote(value));
339
368
    }
340
416
    if let Some(value) = cond.code_contains {
341
118
        build_where_like(builder, "code", value.to_lowercase().as_str());
342
298
    }
343
416
    if let Some(value) = cond.name_contains {
344
16
        build_where_like(builder, "name", value.to_lowercase().as_str());
345
400
    }
346
416
    builder
347
416
}
348

            
349
/// Transforms model options to the SQL builder.
350
324
fn build_limit_offset<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
351
324
    if let Some(value) = opts.limit {
352
162
        if value > 0 {
353
160
            builder.limit(value);
354
160
        }
355
162
    }
356
324
    if let Some(value) = opts.offset {
357
324
        match opts.limit {
358
162
            None => builder.limit(-1).offset(value),
359
2
            Some(0) => builder.limit(-1).offset(value),
360
160
            _ => builder.offset(value),
361
        };
362
    }
363
324
    builder
364
324
}
365

            
366
/// Transforms model options to the SQL builder.
367
324
fn build_sort<'a>(builder: &'a mut SqlBuilder, opts: &ListOptions) -> &'a mut SqlBuilder {
368
324
    if let Some(sort_cond) = opts.sort.as_ref() {
369
244
        for cond in sort_cond.iter() {
370
244
            let key = match cond.key {
371
16
                SortKey::CreatedAt => "created_at",
372
12
                SortKey::ModifiedAt => "modified_at",
373
200
                SortKey::Code => "code",
374
16
                SortKey::Name => "name",
375
            };
376
244
            builder.order_by(key, !cond.asc);
377
        }
378
82
    }
379
324
    builder
380
324
}
381

            
382
/// Transforms query conditions and the model object to the SQL builder.
383
34
fn build_update_where<'a>(
384
34
    builder: &'a mut SqlBuilder,
385
34
    cond: &UpdateQueryCond<'a>,
386
34
    updates: &Updates,
387
34
) -> Option<&'a mut SqlBuilder> {
388
34
    let mut count = 0;
389
34
    if let Some(value) = updates.modified_at.as_ref() {
390
32
        builder.set("modified_at", value.timestamp_millis());
391
32
        count += 1;
392
32
    }
393
34
    if let Some(value) = updates.host_uri.as_ref() {
394
16
        builder.set("host_uri", quote(value));
395
16
        count += 1;
396
18
    }
397
34
    if let Some(value) = updates.name.as_ref() {
398
28
        builder.set("name", quote(value));
399
28
        count += 1;
400
28
    }
401
34
    if let Some(value) = updates.info {
402
28
        match serde_json::to_string(value) {
403
            Err(_) => {
404
                builder.set("info", quote("{}"));
405
            }
406
28
            Ok(value) => {
407
28
                builder.set("info", quote(value));
408
28
            }
409
        }
410
28
        count += 1;
411
6
    }
412
34
    if count == 0 {
413
2
        return None;
414
32
    }
415

            
416
32
    builder.and_where_eq("network_id", quote(cond.network_id));
417
32
    Some(builder)
418
34
}