1
use std::{borrow::Cow, sync::Arc};
2

            
3
use async_trait::async_trait;
4
use chrono::{TimeDelta, Utc};
5
use log::error;
6
use oxide_auth::primitives::{
7
    grant::{Extensions, Grant},
8
    issuer::{IssuedToken, RefreshedToken, TokenType},
9
    registrar::{BoundClient, ClientUrl, ExactUrl, PreGrant, RegisteredUrl, RegistrarError},
10
    scope::Scope,
11
};
12
use oxide_auth_async::primitives::{Authorizer, Issuer, Registrar};
13

            
14
use sylvia_iot_corelib::{err::E_UNKNOWN, strings};
15

            
16
use crate::models::{
17
    Model,
18
    access_token::{self, AccessToken, QueryCond as AccessTokenQuery},
19
    authorization_code::{self, AuthorizationCode, QueryCond as AuthorizationCodeQuery},
20
    client::QueryCond,
21
    refresh_token::{self, QueryCond as RefreshTokenQuery, RefreshToken},
22
};
23

            
24
#[derive(Clone)]
25
pub struct Primitive {
26
    model: Arc<dyn Model>,
27
}
28

            
29
impl Primitive {
30
55836
    pub fn new(model: Arc<dyn Model>) -> Self {
31
55836
        Primitive {
32
55836
            model: model.clone(),
33
55836
        }
34
55836
    }
35
}
36

            
37
#[async_trait]
38
impl Authorizer for Primitive {
39
548
    async fn authorize(&mut self, grant: Grant) -> Result<String, ()> {
40
        const FN_NAME: &'static str = "authorize";
41

            
42
        let scope = grant.scope.to_string();
43
        let code = AuthorizationCode {
44
            code: strings::random_id_sha(&grant.until, 4),
45
            expires_at: match TimeDelta::try_seconds(authorization_code::EXPIRES) {
46
                None => panic!("{}", E_UNKNOWN),
47
                Some(t) => Utc::now() + t,
48
            },
49
            redirect_uri: grant.redirect_uri.to_string(),
50
            scope: match scope.len() {
51
                0 => None,
52
                _ => Some(scope),
53
            },
54
            client_id: grant.client_id,
55
            user_id: grant.owner_id,
56
        };
57

            
58
        match self.model.authorization_code().add(&code).await {
59
            Err(e) => {
60
                error!("[{}] add authorization code error: {}", FN_NAME, e);
61
                Err(())
62
            }
63
            Ok(()) => Ok(code.code),
64
        }
65
548
    }
66

            
67
508
    async fn extract(&mut self, code: &str) -> Result<Option<Grant>, ()> {
68
        const FN_NAME: &'static str = "extract";
69

            
70
        let auth_code = match self.model.authorization_code().get(code).await {
71
            Err(_) => return Err(()),
72
            Ok(code) => match code {
73
                None => return Ok(None),
74
                Some(code) => code,
75
            },
76
        };
77
        {
78
            let query = AuthorizationCodeQuery {
79
                code: Some(code),
80
                ..Default::default()
81
            };
82
            if let Err(e) = self.model.authorization_code().del(&query).await {
83
                error!("[{}] delete authorization code error: {}", FN_NAME, e);
84
                return Err(());
85
            }
86
        }
87
        if auth_code.expires_at < Utc::now() {
88
            return Ok(None);
89
        }
90

            
91
        Ok(Some(Grant {
92
            owner_id: auth_code.user_id,
93
            client_id: auth_code.client_id,
94
            scope: match auth_code.scope {
95
                None => "".parse().unwrap(),
96
                Some(scope) => match scope.as_str().parse() {
97
                    Err(e) => {
98
                        error!("[{}] parse authorization code scope error: {}", FN_NAME, e);
99
                        return Err(());
100
                    }
101
                    Ok(scope) => scope,
102
                },
103
            },
104
            redirect_uri: match auth_code.redirect_uri.parse() {
105
                Err(e) => {
106
                    error!(
107
                        "[{}] parse authorization code redirect_uri error: {}",
108
                        FN_NAME, e
109
                    );
110
                    return Err(());
111
                }
112
                Ok(uri) => uri,
113
            },
114
            until: auth_code.expires_at,
115
            extensions: Extensions::new(),
116
        }))
117
508
    }
118
}
119

            
120
#[async_trait]
121
impl Issuer for Primitive {
122
524
    async fn issue(&mut self, grant: Grant) -> Result<IssuedToken, ()> {
123
        const FN_NAME: &'static str = "issue";
124

            
125
        let now = Utc::now();
126
        let refresh_token = strings::random_id_sha(&now, 8);
127
        let access = AccessToken {
128
            access_token: strings::random_id_sha(&now, 8),
129
            refresh_token: Some(refresh_token.clone()),
130
            expires_at: match TimeDelta::try_seconds(access_token::EXPIRES) {
131
                None => panic!("{}", E_UNKNOWN),
132
                Some(t) => now + t,
133
            },
134
            scope: Some(grant.scope.to_string()),
135
            client_id: grant.client_id.clone(),
136
            redirect_uri: grant.redirect_uri.to_string(),
137
            user_id: grant.owner_id.clone(),
138
        };
139
        let refresh = RefreshToken {
140
            refresh_token,
141
            expires_at: match TimeDelta::try_seconds(refresh_token::EXPIRES) {
142
                None => panic!("{}", E_UNKNOWN),
143
                Some(t) => now + t,
144
            },
145
            scope: Some(grant.scope.to_string()),
146
            client_id: grant.client_id,
147
            redirect_uri: grant.redirect_uri.to_string(),
148
            user_id: grant.owner_id,
149
        };
150

            
151
        if let Err(e) = self.model.access_token().add(&access).await {
152
            error!("[{}] add access token error: {}", FN_NAME, e);
153
            return Err(());
154
        }
155
        if let Err(e) = self.model.refresh_token().add(&refresh).await {
156
            error!("[{}] add refresh token error: {}", FN_NAME, e);
157
            return Err(());
158
        }
159
        Ok(IssuedToken {
160
            token: access.access_token,
161
            refresh: Some(refresh.refresh_token),
162
            until: access.expires_at,
163
            token_type: TokenType::Bearer,
164
        })
165
524
    }
166

            
167
24
    async fn refresh(&mut self, token: &str, grant: Grant) -> Result<RefreshedToken, ()> {
168
        const FN_NAME: &'static str = "refresh";
169

            
170
        let query = AccessTokenQuery {
171
            refresh_token: Some(token),
172
            ..Default::default()
173
        };
174
        if let Err(e) = self.model.access_token().del(&query).await {
175
            error!("[{}] delete access token error: {}", FN_NAME, e);
176
            return Err(());
177
        }
178
        let query = RefreshTokenQuery {
179
            refresh_token: Some(token),
180
            ..Default::default()
181
        };
182
        if let Err(e) = self.model.refresh_token().del(&query).await {
183
            error!("[{}] delete refresh token error: {}", FN_NAME, e);
184
            return Err(());
185
        }
186

            
187
        match self.issue(grant).await {
188
            Err(_) => Err(()),
189
            Ok(token) => Ok(RefreshedToken {
190
                token: token.token,
191
                refresh: token.refresh,
192
                until: token.until,
193
                token_type: token.token_type,
194
            }),
195
        }
196
24
    }
197

            
198
996
    async fn recover_token(&mut self, token: &str) -> Result<Option<Grant>, ()> {
199
        const FN_NAME: &'static str = "recover_token";
200

            
201
        let access = match self.model.access_token().get(token).await {
202
            Err(e) => {
203
                error!("[{}] get access token error: {}", FN_NAME, e);
204
                return Err(());
205
            }
206
            Ok(token) => match token {
207
                None => return Ok(None),
208
                Some(token) => token,
209
            },
210
        };
211
        if access.expires_at < Utc::now() {
212
            return Ok(None);
213
        }
214

            
215
        Ok(Some(Grant {
216
            owner_id: access.user_id,
217
            client_id: access.client_id,
218
            scope: match access.scope {
219
                None => "".parse().unwrap(),
220
                Some(scope) => match scope.as_str().parse() {
221
                    Err(e) => {
222
                        error!("[{}] parse access token scope error: {}", FN_NAME, e);
223
                        return Err(());
224
                    }
225
                    Ok(scope) => scope,
226
                },
227
            },
228
            redirect_uri: match access.redirect_uri.parse() {
229
                Err(e) => {
230
                    error!("[{}] parse access token redirect_uri error: {}", FN_NAME, e);
231
                    return Err(());
232
                }
233
                Ok(uri) => uri,
234
            },
235
            until: access.expires_at,
236
            extensions: Extensions::new(),
237
        }))
238
996
    }
239

            
240
60
    async fn recover_refresh(&mut self, token: &str) -> Result<Option<Grant>, ()> {
241
        const FN_NAME: &'static str = "recover_refresh";
242

            
243
        let refresh = match self.model.refresh_token().get(token).await {
244
            Err(e) => {
245
                error!("[{}] get refresh token error: {}", FN_NAME, e);
246
                return Err(());
247
            }
248
            Ok(token) => match token {
249
                None => return Ok(None),
250
                Some(token) => token,
251
            },
252
        };
253
        if refresh.expires_at < Utc::now() {
254
            return Ok(None);
255
        }
256

            
257
        Ok(Some(Grant {
258
            owner_id: refresh.user_id,
259
            client_id: refresh.client_id,
260
            scope: match refresh.scope {
261
                None => "".parse().unwrap(),
262
                Some(scope) => match scope.as_str().parse() {
263
                    Err(e) => {
264
                        error!("[{}] parse access token scope error: {}", FN_NAME, e);
265
                        return Err(());
266
                    }
267
                    Ok(scope) => scope,
268
                },
269
            },
270
            redirect_uri: match refresh.redirect_uri.parse() {
271
                Err(e) => {
272
                    error!("[{}] parse access token redirect_uri error: {}", FN_NAME, e);
273
                    return Err(());
274
                }
275
                Ok(uri) => uri,
276
            },
277
            until: refresh.expires_at,
278
            extensions: Extensions::new(),
279
        }))
280
60
    }
281
}
282

            
283
#[async_trait]
284
impl Registrar for Primitive {
285
    async fn bound_redirect<'a>(
286
        &self,
287
        bound: ClientUrl<'a>,
288
612
    ) -> Result<BoundClient<'a>, RegistrarError> {
289
        const FN_NAME: &'static str = "bound_redirect";
290

            
291
        let cond = QueryCond {
292
            client_id: Some(&bound.client_id),
293
            ..Default::default()
294
        };
295
        let redirect_uris = match self.model.client().get(&cond).await {
296
            Err(e) => {
297
                error!("[{}] get client error: {}", FN_NAME, e);
298
                return Err(RegistrarError::PrimitiveError);
299
            }
300
            Ok(client) => match client {
301
                None => return Err(RegistrarError::Unspecified),
302
                Some(client) => client.redirect_uris,
303
            },
304
        };
305

            
306
        let redirect_uri = match bound.redirect_uri {
307
            None => match redirect_uris.len() {
308
                0 => return Err(RegistrarError::Unspecified),
309
                _ => redirect_uris.get(0).unwrap(),
310
            },
311
            Some(url) => match redirect_uris
312
                .iter()
313
604
                .find(|uri| uri.as_str() == url.as_str())
314
            {
315
                None => return Err(RegistrarError::Unspecified),
316
                Some(uri) => uri,
317
            },
318
        };
319
        let redirect_uri = match ExactUrl::new(redirect_uri.clone()) {
320
            Err(_) => return Err(RegistrarError::Unspecified),
321
            Ok(url) => url,
322
        };
323

            
324
        Ok(BoundClient {
325
            client_id: bound.client_id,
326
            redirect_uri: Cow::Owned(RegisteredUrl::Exact(redirect_uri)),
327
        })
328
612
    }
329

            
330
    async fn negotiate<'a>(
331
        &self,
332
        bound: BoundClient<'a>,
333
        scope: Option<Scope>,
334
596
    ) -> Result<PreGrant, RegistrarError> {
335
        const FN_NAME: &'static str = "negotiate";
336

            
337
        let cond = QueryCond {
338
            client_id: Some(&bound.client_id),
339
            ..Default::default()
340
        };
341
        let client = match self.model.client().get(&cond).await {
342
            Err(e) => {
343
                return {
344
                    error!("[{}] get client error: {}", FN_NAME, e);
345
                    Err(RegistrarError::PrimitiveError)
346
                };
347
            }
348
            Ok(client) => match client {
349
                None => return Err(RegistrarError::Unspecified),
350
                Some(client) => client,
351
            },
352
        };
353

            
354
        if client.scopes.len() > 0 {
355
            match scope {
356
                None => return Err(RegistrarError::Unspecified),
357
                Some(scope) => {
358
                    let client_scopes = match client.scopes.join(" ").parse::<Scope>() {
359
                        Err(e) => {
360
                            error!("[{}] parse client scope error: {}", FN_NAME, e);
361
                            return Err(RegistrarError::PrimitiveError);
362
                        }
363
                        Ok(scopes) => scopes,
364
                    };
365
                    if !scope.allow_access(&client_scopes) {
366
                        return Err(RegistrarError::Unspecified);
367
                    }
368
                }
369
            }
370
        }
371

            
372
        Ok(PreGrant {
373
            client_id: bound.client_id.into_owned(),
374
            redirect_uri: bound.redirect_uri.into_owned(),
375
            scope: match client.scopes.join(" ").parse() {
376
                Err(e) => {
377
                    error!("[{}] parse client scope error: {}", FN_NAME, e);
378
                    return Err(RegistrarError::PrimitiveError);
379
                }
380
                Ok(scopes) => scopes,
381
            },
382
        })
383
596
    }
384

            
385
    async fn check(
386
        &self,
387
        client_id: &str,
388
        passphrase: Option<&[u8]>,
389
576
    ) -> Result<(), RegistrarError> {
390
        const FN_NAME: &'static str = "check";
391

            
392
        let cond = QueryCond {
393
            client_id: Some(client_id),
394
            ..Default::default()
395
        };
396
        let client = match self.model.client().get(&cond).await {
397
            Err(e) => {
398
                error!("[{}] get client error: {}", FN_NAME, e);
399
                return Err(RegistrarError::PrimitiveError);
400
            }
401
            Ok(client) => match client {
402
                None => return Err(RegistrarError::Unspecified),
403
                Some(client) => client,
404
            },
405
        };
406

            
407
        match (passphrase, client.client_secret) {
408
            (None, None) => Ok(()),
409
            (Some(passphrase), Some(client_secret)) => {
410
                match passphrase == client_secret.as_bytes() {
411
                    true => Ok(()),
412
                    false => Err(RegistrarError::Unspecified),
413
                }
414
            }
415
            _ => Err(RegistrarError::Unspecified),
416
        }
417
576
    }
418
}