1
use std::{borrow::Cow, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeDelta, Utc};
5
use log::error;
6
use oxide_auth::primitives::{
7
    grant::{Extensions, Grant},
8
    issuer::{IssuedToken, RefreshedToken, TokenType},
9
    registrar::{BoundClient, ClientUrl, ExactUrl, PreGrant, RegisteredUrl, RegistrarError},
10
    scope::Scope,
11
};
12
use oxide_auth_async::primitives::{Authorizer, Issuer, Registrar};
13

            
14
use sylvia_iot_corelib::{err::E_UNKNOWN, strings};
15

            
16
use crate::models::{
17
    access_token::{self, AccessToken, QueryCond as AccessTokenQuery},
18
    authorization_code::{self, AuthorizationCode, QueryCond as AuthorizationCodeQuery},
19
    client::QueryCond,
20
    refresh_token::{self, QueryCond as RefreshTokenQuery, RefreshToken},
21
    Model,
22
};
23

            
24
#[derive(Clone)]
25
pub struct Primitive {
26
    model: Arc<dyn Model>,
27
}
28

            
29
impl Primitive {
30
55836
    pub fn new(model: Arc<dyn Model>) -> Self {
31
55836
        Primitive {
32
55836
            model: model.clone(),
33
55836
        }
34
55836
    }
35
}
36

            
37
#[async_trait]
38
impl Authorizer for Primitive {
39
548
    async fn authorize(&mut self, grant: Grant) -> Result<String, ()> {
40
        const FN_NAME: &'static str = "authorize";
41

            
42
548
        let scope = grant.scope.to_string();
43
548
        let code = AuthorizationCode {
44
548
            code: strings::random_id_sha(&grant.until, 4),
45
548
            expires_at: match TimeDelta::try_seconds(authorization_code::EXPIRES) {
46
                None => panic!("{}", E_UNKNOWN),
47
548
                Some(t) => Utc::now() + t,
48
548
            },
49
548
            redirect_uri: grant.redirect_uri.to_string(),
50
548
            scope: match scope.len() {
51
472
                0 => None,
52
76
                _ => Some(scope),
53
            },
54
548
            client_id: grant.client_id,
55
548
            user_id: grant.owner_id,
56
548
        };
57
548

            
58
548
        match self.model.authorization_code().add(&code).await {
59
            Err(e) => {
60
                error!("[{}] add authorization code error: {}", FN_NAME, e);
61
                Err(())
62
            }
63
548
            Ok(()) => Ok(code.code),
64
        }
65
1096
    }
66

            
67
508
    async fn extract(&mut self, code: &str) -> Result<Option<Grant>, ()> {
68
        const FN_NAME: &'static str = "extract";
69

            
70
508
        let auth_code = match self.model.authorization_code().get(code).await {
71
            Err(_) => return Err(()),
72
508
            Ok(code) => match code {
73
8
                None => return Ok(None),
74
500
                Some(code) => code,
75
500
            },
76
500
        };
77
500
        {
78
500
            let query = AuthorizationCodeQuery {
79
500
                code: Some(code),
80
500
                ..Default::default()
81
500
            };
82
500
            if let Err(e) = self.model.authorization_code().del(&query).await {
83
                error!("[{}] delete authorization code error: {}", FN_NAME, e);
84
                return Err(());
85
500
            }
86
500
        }
87
500
        if auth_code.expires_at < Utc::now() {
88
            return Ok(None);
89
500
        }
90
500

            
91
500
        Ok(Some(Grant {
92
500
            owner_id: auth_code.user_id,
93
500
            client_id: auth_code.client_id,
94
500
            scope: match auth_code.scope {
95
448
                None => "".parse().unwrap(),
96
52
                Some(scope) => match scope.as_str().parse() {
97
                    Err(e) => {
98
                        error!("[{}] parse authorization code scope error: {}", FN_NAME, e);
99
                        return Err(());
100
                    }
101
52
                    Ok(scope) => scope,
102
                },
103
            },
104
500
            redirect_uri: match auth_code.redirect_uri.parse() {
105
                Err(e) => {
106
                    error!(
107
                        "[{}] parse authorization code redirect_uri error: {}",
108
                        FN_NAME, e
109
                    );
110
                    return Err(());
111
                }
112
500
                Ok(uri) => uri,
113
500
            },
114
500
            until: auth_code.expires_at,
115
500
            extensions: Extensions::new(),
116
        }))
117
1016
    }
118
}
119

            
120
#[async_trait]
121
impl Issuer for Primitive {
122
524
    async fn issue(&mut self, grant: Grant) -> Result<IssuedToken, ()> {
123
        const FN_NAME: &'static str = "issue";
124

            
125
524
        let now = Utc::now();
126
524
        let refresh_token = strings::random_id_sha(&now, 8);
127
524
        let access = AccessToken {
128
524
            access_token: strings::random_id_sha(&now, 8),
129
524
            refresh_token: Some(refresh_token.clone()),
130
524
            expires_at: match TimeDelta::try_seconds(access_token::EXPIRES) {
131
                None => panic!("{}", E_UNKNOWN),
132
524
                Some(t) => now + t,
133
524
            },
134
524
            scope: Some(grant.scope.to_string()),
135
524
            client_id: grant.client_id.clone(),
136
524
            redirect_uri: grant.redirect_uri.to_string(),
137
524
            user_id: grant.owner_id.clone(),
138
        };
139
524
        let refresh = RefreshToken {
140
524
            refresh_token,
141
524
            expires_at: match TimeDelta::try_seconds(refresh_token::EXPIRES) {
142
                None => panic!("{}", E_UNKNOWN),
143
524
                Some(t) => now + t,
144
524
            },
145
524
            scope: Some(grant.scope.to_string()),
146
524
            client_id: grant.client_id,
147
524
            redirect_uri: grant.redirect_uri.to_string(),
148
524
            user_id: grant.owner_id,
149
        };
150

            
151
524
        if let Err(e) = self.model.access_token().add(&access).await {
152
            error!("[{}] add access token error: {}", FN_NAME, e);
153
            return Err(());
154
524
        }
155
524
        if let Err(e) = self.model.refresh_token().add(&refresh).await {
156
            error!("[{}] add refresh token error: {}", FN_NAME, e);
157
            return Err(());
158
524
        }
159
524
        Ok(IssuedToken {
160
524
            token: access.access_token,
161
524
            refresh: Some(refresh.refresh_token),
162
524
            until: access.expires_at,
163
524
            token_type: TokenType::Bearer,
164
524
        })
165
1048
    }
166

            
167
24
    async fn refresh(&mut self, token: &str, grant: Grant) -> Result<RefreshedToken, ()> {
168
        const FN_NAME: &'static str = "refresh";
169

            
170
24
        let query = AccessTokenQuery {
171
24
            refresh_token: Some(token),
172
24
            ..Default::default()
173
24
        };
174
24
        if let Err(e) = self.model.access_token().del(&query).await {
175
            error!("[{}] delete access token error: {}", FN_NAME, e);
176
            return Err(());
177
24
        }
178
24
        let query = RefreshTokenQuery {
179
24
            refresh_token: Some(token),
180
24
            ..Default::default()
181
24
        };
182
24
        if let Err(e) = self.model.refresh_token().del(&query).await {
183
            error!("[{}] delete refresh token error: {}", FN_NAME, e);
184
            return Err(());
185
24
        }
186
24

            
187
24
        match self.issue(grant).await {
188
            Err(_) => Err(()),
189
24
            Ok(token) => Ok(RefreshedToken {
190
24
                token: token.token,
191
24
                refresh: token.refresh,
192
24
                until: token.until,
193
24
                token_type: token.token_type,
194
24
            }),
195
        }
196
48
    }
197

            
198
996
    async fn recover_token(&mut self, token: &str) -> Result<Option<Grant>, ()> {
199
        const FN_NAME: &'static str = "recover_token";
200

            
201
996
        let access = match self.model.access_token().get(token).await {
202
            Err(e) => {
203
                error!("[{}] get access token error: {}", FN_NAME, e);
204
                return Err(());
205
            }
206
996
            Ok(token) => match token {
207
84
                None => return Ok(None),
208
912
                Some(token) => token,
209
912
            },
210
912
        };
211
912
        if access.expires_at < Utc::now() {
212
            return Ok(None);
213
912
        }
214
912

            
215
912
        Ok(Some(Grant {
216
912
            owner_id: access.user_id,
217
912
            client_id: access.client_id,
218
912
            scope: match access.scope {
219
4
                None => "".parse().unwrap(),
220
908
                Some(scope) => match scope.as_str().parse() {
221
                    Err(e) => {
222
                        error!("[{}] parse access token scope error: {}", FN_NAME, e);
223
                        return Err(());
224
                    }
225
908
                    Ok(scope) => scope,
226
                },
227
            },
228
912
            redirect_uri: match access.redirect_uri.parse() {
229
                Err(e) => {
230
                    error!("[{}] parse access token redirect_uri error: {}", FN_NAME, e);
231
                    return Err(());
232
                }
233
912
                Ok(uri) => uri,
234
912
            },
235
912
            until: access.expires_at,
236
912
            extensions: Extensions::new(),
237
        }))
238
1992
    }
239

            
240
60
    async fn recover_refresh(&mut self, token: &str) -> Result<Option<Grant>, ()> {
241
        const FN_NAME: &'static str = "recover_refresh";
242

            
243
60
        let refresh = match self.model.refresh_token().get(token).await {
244
            Err(e) => {
245
                error!("[{}] get refresh token error: {}", FN_NAME, e);
246
                return Err(());
247
            }
248
60
            Ok(token) => match token {
249
8
                None => return Ok(None),
250
52
                Some(token) => token,
251
52
            },
252
52
        };
253
52
        if refresh.expires_at < Utc::now() {
254
            return Ok(None);
255
52
        }
256
52

            
257
52
        Ok(Some(Grant {
258
52
            owner_id: refresh.user_id,
259
52
            client_id: refresh.client_id,
260
52
            scope: match refresh.scope {
261
                None => "".parse().unwrap(),
262
52
                Some(scope) => match scope.as_str().parse() {
263
                    Err(e) => {
264
                        error!("[{}] parse access token scope error: {}", FN_NAME, e);
265
                        return Err(());
266
                    }
267
52
                    Ok(scope) => scope,
268
                },
269
            },
270
52
            redirect_uri: match refresh.redirect_uri.parse() {
271
                Err(e) => {
272
                    error!("[{}] parse access token redirect_uri error: {}", FN_NAME, e);
273
                    return Err(());
274
                }
275
52
                Ok(uri) => uri,
276
52
            },
277
52
            until: refresh.expires_at,
278
52
            extensions: Extensions::new(),
279
        }))
280
120
    }
281
}
282

            
283
#[async_trait]
284
impl Registrar for Primitive {
285
    async fn bound_redirect<'a>(
286
        &self,
287
        bound: ClientUrl<'a>,
288
612
    ) -> Result<BoundClient<'a>, RegistrarError> {
289
        const FN_NAME: &'static str = "bound_redirect";
290

            
291
612
        let cond = QueryCond {
292
612
            client_id: Some(&bound.client_id),
293
612
            ..Default::default()
294
612
        };
295
612
        let redirect_uris = match self.model.client().get(&cond).await {
296
            Err(e) => {
297
                error!("[{}] get client error: {}", FN_NAME, e);
298
                return Err(RegistrarError::PrimitiveError);
299
            }
300
612
            Ok(client) => match client {
301
8
                None => return Err(RegistrarError::Unspecified),
302
604
                Some(client) => client.redirect_uris,
303
            },
304
        };
305

            
306
604
        let redirect_uri = match bound.redirect_uri {
307
            None => match redirect_uris.len() {
308
                0 => return Err(RegistrarError::Unspecified),
309
                _ => redirect_uris.get(0).unwrap(),
310
            },
311
604
            Some(url) => match redirect_uris
312
604
                .iter()
313
604
                .find(|uri| uri.as_str() == url.as_str())
314
            {
315
                None => return Err(RegistrarError::Unspecified),
316
604
                Some(uri) => uri,
317
            },
318
        };
319
604
        let redirect_uri = match ExactUrl::new(redirect_uri.clone()) {
320
            Err(_) => return Err(RegistrarError::Unspecified),
321
604
            Ok(url) => url,
322
604
        };
323
604

            
324
604
        Ok(BoundClient {
325
604
            client_id: bound.client_id,
326
604
            redirect_uri: Cow::Owned(RegisteredUrl::Exact(redirect_uri)),
327
604
        })
328
1224
    }
329

            
330
    async fn negotiate<'a>(
331
        &self,
332
        bound: BoundClient<'a>,
333
        scope: Option<Scope>,
334
596
    ) -> Result<PreGrant, RegistrarError> {
335
        const FN_NAME: &'static str = "negotiate";
336

            
337
596
        let cond = QueryCond {
338
596
            client_id: Some(&bound.client_id),
339
596
            ..Default::default()
340
596
        };
341
596
        let client = match self.model.client().get(&cond).await {
342
            Err(e) => {
343
                return {
344
                    error!("[{}] get client error: {}", FN_NAME, e);
345
                    Err(RegistrarError::PrimitiveError)
346
                }
347
            }
348
596
            Ok(client) => match client {
349
                None => return Err(RegistrarError::Unspecified),
350
596
                Some(client) => client,
351
596
            },
352
596
        };
353
596

            
354
596
        if client.scopes.len() > 0 {
355
112
            match scope {
356
8
                None => return Err(RegistrarError::Unspecified),
357
104
                Some(scope) => {
358
104
                    let client_scopes = match client.scopes.join(" ").parse::<Scope>() {
359
                        Err(e) => {
360
                            error!("[{}] parse client scope error: {}", FN_NAME, e);
361
                            return Err(RegistrarError::PrimitiveError);
362
                        }
363
104
                        Ok(scopes) => scopes,
364
104
                    };
365
104
                    if !scope.allow_access(&client_scopes) {
366
8
                        return Err(RegistrarError::Unspecified);
367
96
                    }
368
                }
369
            }
370
484
        }
371

            
372
        Ok(PreGrant {
373
580
            client_id: bound.client_id.into_owned(),
374
580
            redirect_uri: bound.redirect_uri.into_owned(),
375
580
            scope: match client.scopes.join(" ").parse() {
376
                Err(e) => {
377
                    error!("[{}] parse client scope error: {}", FN_NAME, e);
378
                    return Err(RegistrarError::PrimitiveError);
379
                }
380
580
                Ok(scopes) => scopes,
381
            },
382
        })
383
1192
    }
384

            
385
    async fn check(
386
        &self,
387
        client_id: &str,
388
        passphrase: Option<&[u8]>,
389
576
    ) -> Result<(), RegistrarError> {
390
        const FN_NAME: &'static str = "check";
391

            
392
576
        let cond = QueryCond {
393
576
            client_id: Some(client_id),
394
576
            ..Default::default()
395
576
        };
396
576
        let client = match self.model.client().get(&cond).await {
397
            Err(e) => {
398
                error!("[{}] get client error: {}", FN_NAME, e);
399
                return Err(RegistrarError::PrimitiveError);
400
            }
401
576
            Ok(client) => match client {
402
8
                None => return Err(RegistrarError::Unspecified),
403
568
                Some(client) => client,
404
568
            },
405
568
        };
406
568

            
407
568
        match (passphrase, client.client_secret) {
408
476
            (None, None) => Ok(()),
409
64
            (Some(passphrase), Some(client_secret)) => {
410
64
                match passphrase == client_secret.as_bytes() {
411
64
                    true => Ok(()),
412
                    false => Err(RegistrarError::Unspecified),
413
                }
414
            }
415
28
            _ => Err(RegistrarError::Unspecified),
416
        }
417
1152
    }
418
}