1
use std::{borrow::Cow, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeDelta, Utc};
5
use log::error;
6
use oxide_auth::primitives::{
7
    grant::{Extensions, Grant},
8
    issuer::{IssuedToken, RefreshedToken, TokenType},
9
    registrar::{BoundClient, ClientUrl, ExactUrl, PreGrant, RegisteredUrl, RegistrarError},
10
    scope::Scope,
11
};
12
use oxide_auth_async::primitives::{Authorizer, Issuer, Registrar};
13

            
14
use sylvia_iot_corelib::{err::E_UNKNOWN, strings};
15

            
16
use crate::models::{
17
    access_token::{self, AccessToken, QueryCond as AccessTokenQuery},
18
    authorization_code::{self, AuthorizationCode, QueryCond as AuthorizationCodeQuery},
19
    client::QueryCond,
20
    refresh_token::{self, QueryCond as RefreshTokenQuery, RefreshToken},
21
    Model,
22
};
23

            
24
#[derive(Clone)]
25
pub struct Primitive {
26
    model: Arc<dyn Model>,
27
}
28

            
29
impl Primitive {
30
27918
    pub fn new(model: Arc<dyn Model>) -> Self {
31
27918
        Primitive {
32
27918
            model: model.clone(),
33
27918
        }
34
27918
    }
35
}
36

            
37
#[async_trait]
38
impl Authorizer for Primitive {
39
274
    async fn authorize(&mut self, grant: Grant) -> Result<String, ()> {
40
        const FN_NAME: &'static str = "authorize";
41

            
42
274
        let scope = grant.scope.to_string();
43
274
        let code = AuthorizationCode {
44
274
            code: strings::random_id_sha(&grant.until, 4),
45
274
            expires_at: match TimeDelta::try_seconds(authorization_code::EXPIRES) {
46
                None => panic!("{}", E_UNKNOWN),
47
274
                Some(t) => Utc::now() + t,
48
274
            },
49
274
            redirect_uri: grant.redirect_uri.to_string(),
50
274
            scope: match scope.len() {
51
236
                0 => None,
52
38
                _ => Some(scope),
53
            },
54
274
            client_id: grant.client_id,
55
274
            user_id: grant.owner_id,
56
274
        };
57
274

            
58
548
        match self.model.authorization_code().add(&code).await {
59
            Err(e) => {
60
                error!("[{}] add authorization code error: {}", FN_NAME, e);
61
                Err(())
62
            }
63
274
            Ok(()) => Ok(code.code),
64
        }
65
548
    }
66

            
67
254
    async fn extract(&mut self, code: &str) -> Result<Option<Grant>, ()> {
68
        const FN_NAME: &'static str = "extract";
69

            
70
508
        let auth_code = match self.model.authorization_code().get(code).await {
71
            Err(_) => return Err(()),
72
254
            Ok(code) => match code {
73
4
                None => return Ok(None),
74
250
                Some(code) => code,
75
250
            },
76
250
        };
77
250
        {
78
250
            let query = AuthorizationCodeQuery {
79
250
                code: Some(code),
80
250
                ..Default::default()
81
250
            };
82
499
            if let Err(e) = self.model.authorization_code().del(&query).await {
83
                error!("[{}] delete authorization code error: {}", FN_NAME, e);
84
                return Err(());
85
250
            }
86
250
        }
87
250
        if auth_code.expires_at < Utc::now() {
88
            return Ok(None);
89
250
        }
90
250

            
91
250
        Ok(Some(Grant {
92
250
            owner_id: auth_code.user_id,
93
250
            client_id: auth_code.client_id,
94
250
            scope: match auth_code.scope {
95
224
                None => "".parse().unwrap(),
96
26
                Some(scope) => match scope.as_str().parse() {
97
                    Err(e) => {
98
                        error!("[{}] parse authorization code scope error: {}", FN_NAME, e);
99
                        return Err(());
100
                    }
101
26
                    Ok(scope) => scope,
102
                },
103
            },
104
250
            redirect_uri: match auth_code.redirect_uri.parse() {
105
                Err(e) => {
106
                    error!(
107
                        "[{}] parse authorization code redirect_uri error: {}",
108
                        FN_NAME, e
109
                    );
110
                    return Err(());
111
                }
112
250
                Ok(uri) => uri,
113
250
            },
114
250
            until: auth_code.expires_at,
115
250
            extensions: Extensions::new(),
116
        }))
117
508
    }
118
}
119

            
120
#[async_trait]
121
impl Issuer for Primitive {
122
262
    async fn issue(&mut self, grant: Grant) -> Result<IssuedToken, ()> {
123
        const FN_NAME: &'static str = "issue";
124

            
125
262
        let now = Utc::now();
126
262
        let refresh_token = strings::random_id_sha(&now, 8);
127
262
        let access = AccessToken {
128
262
            access_token: strings::random_id_sha(&now, 8),
129
262
            refresh_token: Some(refresh_token.clone()),
130
262
            expires_at: match TimeDelta::try_seconds(access_token::EXPIRES) {
131
                None => panic!("{}", E_UNKNOWN),
132
262
                Some(t) => now + t,
133
262
            },
134
262
            scope: Some(grant.scope.to_string()),
135
262
            client_id: grant.client_id.clone(),
136
262
            redirect_uri: grant.redirect_uri.to_string(),
137
262
            user_id: grant.owner_id.clone(),
138
        };
139
262
        let refresh = RefreshToken {
140
262
            refresh_token,
141
262
            expires_at: match TimeDelta::try_seconds(refresh_token::EXPIRES) {
142
                None => panic!("{}", E_UNKNOWN),
143
262
                Some(t) => now + t,
144
262
            },
145
262
            scope: Some(grant.scope.to_string()),
146
262
            client_id: grant.client_id,
147
262
            redirect_uri: grant.redirect_uri.to_string(),
148
262
            user_id: grant.owner_id,
149
        };
150

            
151
524
        if let Err(e) = self.model.access_token().add(&access).await {
152
            error!("[{}] add access token error: {}", FN_NAME, e);
153
            return Err(());
154
262
        }
155
524
        if let Err(e) = self.model.refresh_token().add(&refresh).await {
156
            error!("[{}] add refresh token error: {}", FN_NAME, e);
157
            return Err(());
158
262
        }
159
262
        Ok(IssuedToken {
160
262
            token: access.access_token,
161
262
            refresh: Some(refresh.refresh_token),
162
262
            until: access.expires_at,
163
262
            token_type: TokenType::Bearer,
164
262
        })
165
524
    }
166

            
167
12
    async fn refresh(&mut self, token: &str, grant: Grant) -> Result<RefreshedToken, ()> {
168
        const FN_NAME: &'static str = "refresh";
169

            
170
12
        let query = AccessTokenQuery {
171
12
            refresh_token: Some(token),
172
12
            ..Default::default()
173
12
        };
174
24
        if let Err(e) = self.model.access_token().del(&query).await {
175
            error!("[{}] delete access token error: {}", FN_NAME, e);
176
            return Err(());
177
12
        }
178
12
        let query = RefreshTokenQuery {
179
12
            refresh_token: Some(token),
180
12
            ..Default::default()
181
12
        };
182
24
        if let Err(e) = self.model.refresh_token().del(&query).await {
183
            error!("[{}] delete refresh token error: {}", FN_NAME, e);
184
            return Err(());
185
12
        }
186
12

            
187
48
        match self.issue(grant).await {
188
            Err(_) => Err(()),
189
12
            Ok(token) => Ok(RefreshedToken {
190
12
                token: token.token,
191
12
                refresh: token.refresh,
192
12
                until: token.until,
193
12
                token_type: token.token_type,
194
12
            }),
195
        }
196
24
    }
197

            
198
498
    async fn recover_token(&mut self, token: &str) -> Result<Option<Grant>, ()> {
199
        const FN_NAME: &'static str = "recover_token";
200

            
201
996
        let access = match self.model.access_token().get(token).await {
202
            Err(e) => {
203
                error!("[{}] get access token error: {}", FN_NAME, e);
204
                return Err(());
205
            }
206
498
            Ok(token) => match token {
207
42
                None => return Ok(None),
208
456
                Some(token) => token,
209
456
            },
210
456
        };
211
456
        if access.expires_at < Utc::now() {
212
            return Ok(None);
213
456
        }
214
456

            
215
456
        Ok(Some(Grant {
216
456
            owner_id: access.user_id,
217
456
            client_id: access.client_id,
218
456
            scope: match access.scope {
219
2
                None => "".parse().unwrap(),
220
454
                Some(scope) => match scope.as_str().parse() {
221
                    Err(e) => {
222
                        error!("[{}] parse access token scope error: {}", FN_NAME, e);
223
                        return Err(());
224
                    }
225
454
                    Ok(scope) => scope,
226
                },
227
            },
228
456
            redirect_uri: match access.redirect_uri.parse() {
229
                Err(e) => {
230
                    error!("[{}] parse access token redirect_uri error: {}", FN_NAME, e);
231
                    return Err(());
232
                }
233
456
                Ok(uri) => uri,
234
456
            },
235
456
            until: access.expires_at,
236
456
            extensions: Extensions::new(),
237
        }))
238
996
    }
239

            
240
30
    async fn recover_refresh(&mut self, token: &str) -> Result<Option<Grant>, ()> {
241
        const FN_NAME: &'static str = "recover_refresh";
242

            
243
60
        let refresh = match self.model.refresh_token().get(token).await {
244
            Err(e) => {
245
                error!("[{}] get refresh token error: {}", FN_NAME, e);
246
                return Err(());
247
            }
248
30
            Ok(token) => match token {
249
4
                None => return Ok(None),
250
26
                Some(token) => token,
251
26
            },
252
26
        };
253
26
        if refresh.expires_at < Utc::now() {
254
            return Ok(None);
255
26
        }
256
26

            
257
26
        Ok(Some(Grant {
258
26
            owner_id: refresh.user_id,
259
26
            client_id: refresh.client_id,
260
26
            scope: match refresh.scope {
261
                None => "".parse().unwrap(),
262
26
                Some(scope) => match scope.as_str().parse() {
263
                    Err(e) => {
264
                        error!("[{}] parse access token scope error: {}", FN_NAME, e);
265
                        return Err(());
266
                    }
267
26
                    Ok(scope) => scope,
268
                },
269
            },
270
26
            redirect_uri: match refresh.redirect_uri.parse() {
271
                Err(e) => {
272
                    error!("[{}] parse access token redirect_uri error: {}", FN_NAME, e);
273
                    return Err(());
274
                }
275
26
                Ok(uri) => uri,
276
26
            },
277
26
            until: refresh.expires_at,
278
26
            extensions: Extensions::new(),
279
        }))
280
60
    }
281
}
282

            
283
#[async_trait]
284
impl Registrar for Primitive {
285
    async fn bound_redirect<'a>(
286
        &self,
287
        bound: ClientUrl<'a>,
288
306
    ) -> Result<BoundClient<'a>, RegistrarError> {
289
        const FN_NAME: &'static str = "bound_redirect";
290

            
291
306
        let cond = QueryCond {
292
306
            client_id: Some(&bound.client_id),
293
306
            ..Default::default()
294
306
        };
295
612
        let redirect_uris = match self.model.client().get(&cond).await {
296
            Err(e) => {
297
                error!("[{}] get client error: {}", FN_NAME, e);
298
                return Err(RegistrarError::PrimitiveError);
299
            }
300
306
            Ok(client) => match client {
301
4
                None => return Err(RegistrarError::Unspecified),
302
302
                Some(client) => client.redirect_uris,
303
            },
304
        };
305

            
306
302
        let redirect_uri = match bound.redirect_uri {
307
            None => match redirect_uris.len() {
308
                0 => return Err(RegistrarError::Unspecified),
309
                _ => redirect_uris.get(0).unwrap(),
310
            },
311
302
            Some(url) => match redirect_uris
312
302
                .iter()
313
302
                .find(|uri| uri.as_str() == url.as_str())
314
            {
315
                None => return Err(RegistrarError::Unspecified),
316
302
                Some(uri) => uri,
317
            },
318
        };
319
302
        let redirect_uri = match ExactUrl::new(redirect_uri.clone()) {
320
            Err(_) => return Err(RegistrarError::Unspecified),
321
302
            Ok(url) => url,
322
302
        };
323
302

            
324
302
        Ok(BoundClient {
325
302
            client_id: bound.client_id,
326
302
            redirect_uri: Cow::Owned(RegisteredUrl::Exact(redirect_uri)),
327
302
        })
328
612
    }
329

            
330
    async fn negotiate<'a>(
331
        &self,
332
        bound: BoundClient<'a>,
333
        scope: Option<Scope>,
334
298
    ) -> Result<PreGrant, RegistrarError> {
335
        const FN_NAME: &'static str = "negotiate";
336

            
337
298
        let cond = QueryCond {
338
298
            client_id: Some(&bound.client_id),
339
298
            ..Default::default()
340
298
        };
341
596
        let client = match self.model.client().get(&cond).await {
342
            Err(e) => {
343
                return {
344
                    error!("[{}] get client error: {}", FN_NAME, e);
345
                    Err(RegistrarError::PrimitiveError)
346
                }
347
            }
348
298
            Ok(client) => match client {
349
                None => return Err(RegistrarError::Unspecified),
350
298
                Some(client) => client,
351
298
            },
352
298
        };
353
298

            
354
298
        if client.scopes.len() > 0 {
355
56
            match scope {
356
4
                None => return Err(RegistrarError::Unspecified),
357
52
                Some(scope) => {
358
52
                    let client_scopes = match client.scopes.join(" ").parse::<Scope>() {
359
                        Err(e) => {
360
                            error!("[{}] parse client scope error: {}", FN_NAME, e);
361
                            return Err(RegistrarError::PrimitiveError);
362
                        }
363
52
                        Ok(scopes) => scopes,
364
52
                    };
365
52
                    if !scope.allow_access(&client_scopes) {
366
4
                        return Err(RegistrarError::Unspecified);
367
48
                    }
368
                }
369
            }
370
242
        }
371

            
372
        Ok(PreGrant {
373
290
            client_id: bound.client_id.into_owned(),
374
290
            redirect_uri: bound.redirect_uri.into_owned(),
375
290
            scope: match client.scopes.join(" ").parse() {
376
                Err(e) => {
377
                    error!("[{}] parse client scope error: {}", FN_NAME, e);
378
                    return Err(RegistrarError::PrimitiveError);
379
                }
380
290
                Ok(scopes) => scopes,
381
            },
382
        })
383
596
    }
384

            
385
    async fn check(
386
        &self,
387
        client_id: &str,
388
        passphrase: Option<&[u8]>,
389
288
    ) -> Result<(), RegistrarError> {
390
        const FN_NAME: &'static str = "check";
391

            
392
288
        let cond = QueryCond {
393
288
            client_id: Some(client_id),
394
288
            ..Default::default()
395
288
        };
396
576
        let client = match self.model.client().get(&cond).await {
397
            Err(e) => {
398
                error!("[{}] get client error: {}", FN_NAME, e);
399
                return Err(RegistrarError::PrimitiveError);
400
            }
401
288
            Ok(client) => match client {
402
4
                None => return Err(RegistrarError::Unspecified),
403
284
                Some(client) => client,
404
284
            },
405
284
        };
406
284

            
407
284
        match (passphrase, client.client_secret) {
408
238
            (None, None) => Ok(()),
409
32
            (Some(passphrase), Some(client_secret)) => {
410
32
                match passphrase == client_secret.as_bytes() {
411
32
                    true => Ok(()),
412
                    false => Err(RegistrarError::Unspecified),
413
                }
414
            }
415
14
            _ => Err(RegistrarError::Unspecified),
416
        }
417
576
    }
418
}