1
use std::{borrow::Cow, sync::Arc};
2

            
3
use axum::{
4
    Extension,
5
    extract::State,
6
    http::{StatusCode, header},
7
    response::{IntoResponse, Response},
8
};
9
use chrono::{TimeDelta, Utc};
10
use log::{error, warn};
11
use oxide_auth::{
12
    code_grant::{
13
        accesstoken::{Authorization, Error as AccessTokenError, Request, TokenResponse},
14
        authorization::{Error as AuthorizationError, Request as OxideAuthorizationRequest},
15
        refresh::Error as RefreshTokenError,
16
    },
17
    primitives::{
18
        grant::{Extensions, Grant},
19
        scope::Scope,
20
    },
21
};
22
use oxide_auth_async::code_grant::{self, access_token::Endpoint as TokenEndpoint};
23
use serde_urlencoded;
24
use tera::{Context, Tera};
25
use url::Url;
26

            
27
use sylvia_iot_corelib::{constants::ContentType, err::E_UNKNOWN, http::Json, strings};
28

            
29
use super::{
30
    super::State as AppState,
31
    endpoint::Endpoint,
32
    request::{
33
        self, AccessTokenRequest, AuthorizationRequest, GetAuthRequest, GetLoginRequest,
34
        PostLoginRequest, RefreshTokenRequest,
35
    },
36
    response::OAuth2Error,
37
};
38
use crate::models::{
39
    Model,
40
    client::QueryCond as ClientQueryCond,
41
    login_session::{self, LoginSession, QueryCond as SessionQueryCond},
42
    user::QueryCond as UserQueryCond,
43
};
44

            
45
/// Name of the Tera template rendered by `get_login` (`GET /{base}/oauth2/login`).
pub const TMPL_LOGIN: &str = "login";
/// Name of the Tera template rendered by `authorize` when the user has not yet decided.
pub const TMPL_GRANT: &str = "grant";
47

            
48
/// `GET /{base}/oauth2/auth`
49
///
50
/// Authenticate client and redirect to the login page.
51
32
pub async fn get_auth(State(state): State<AppState>, req: GetAuthRequest) -> impl IntoResponse {
52
    const FN_NAME: &'static str = "get_auth";
53

            
54
32
    if let Err(resp) = check_auth_params(FN_NAME, &req, &state.model).await {
55
24
        return resp;
56
8
    }
57

            
58
8
    let login_state: String = match serde_urlencoded::to_string(&req) {
59
        Err(e) => {
60
            let err_str = e.to_string();
61
            error!(
62
                "[{}] encode authorize state error: {}",
63
                FN_NAME,
64
                err_str.as_str()
65
            );
66
            return redirect_server_error(
67
                FN_NAME,
68
                req.redirect_uri.as_str(),
69
                Some(err_str.as_str()),
70
            );
71
        }
72
8
        Ok(str) => match serde_urlencoded::to_string(GetLoginRequest { state: str }) {
73
            Err(e) => {
74
                let err_str = e.to_string();
75
                error!(
76
                    "[{}] encode login state error: {}",
77
                    FN_NAME,
78
                    err_str.as_str()
79
                );
80
                return redirect_server_error(
81
                    FN_NAME,
82
                    req.redirect_uri.as_str(),
83
                    Some(err_str.as_str()),
84
                );
85
            }
86
8
            Ok(str) => str,
87
        },
88
    };
89
8
    resp_found(format!("{}/oauth2/login?{}", state.scope_path, login_state))
90
32
}
91

            
92
/// `GET /{base}/oauth2/login`
93
///
94
/// To render the login page.
95
24
pub async fn get_login(
96
24
    State(state): State<AppState>,
97
24
    Extension(tera): Extension<Tera>,
98
24
    req: GetLoginRequest,
99
24
) -> impl IntoResponse {
100
    const FN_NAME: &'static str = "get_login";
101

            
102
24
    if req.state.as_str().len() == 0 {
103
4
        warn!("[{}] empty state content", FN_NAME);
104
4
        return resp_invalid_request(Some("invalid state content"));
105
20
    }
106
20
    match serde_urlencoded::from_str::<GetAuthRequest>(req.state.as_str()) {
107
4
        Err(e) => {
108
4
            warn!(
109
                "[{}] parse state error: {}, content: {}",
110
                FN_NAME,
111
                e,
112
                req.state.as_str()
113
            );
114
4
            return resp_invalid_request(Some("invalid state content"));
115
        }
116
16
        Ok(inner_req) => {
117
16
            if let Err(resp) = check_auth_params(FN_NAME, &inner_req, &state.model).await {
118
8
                return resp;
119
8
            }
120
        }
121
    }
122

            
123
8
    let mut context = Context::new();
124
8
    context.insert("scope_path", &state.scope_path);
125
8
    context.insert("state", &req.state);
126
8
    let page = match tera.render(TMPL_LOGIN, &context) {
127
        Err(e) => {
128
            let err_str = e.to_string();
129
            error!(
130
                "[{}] render login template error: {}",
131
                FN_NAME,
132
                err_str.as_str()
133
            );
134
            return resp_temporary_unavailable(Some(err_str));
135
        }
136
8
        Ok(page) => page,
137
    };
138

            
139
8
    ([(header::CONTENT_TYPE, "text/html; charset=utf-8")], page).into_response()
140
24
}
141

            
142
/// `POST /{base}/oauth2/login`
143
///
144
/// Do the login process.
145
592
pub async fn post_login(State(state): State<AppState>, req: PostLoginRequest) -> impl IntoResponse {
146
    const FN_NAME: &'static str = "post_login";
147

            
148
592
    if req.state.as_str().len() == 0 {
149
4
        warn!("[{}] empty state content", FN_NAME);
150
4
        return resp_invalid_request(Some("invalid state content"));
151
588
    }
152
588
    match serde_urlencoded::from_str::<GetAuthRequest>(req.state.as_str()) {
153
4
        Err(e) => {
154
4
            warn!(
155
                "[{}] parse state error: {}, content: {}",
156
                FN_NAME,
157
                e,
158
                req.state.as_str()
159
            );
160
4
            return resp_invalid_request(Some("invalid state content"));
161
        }
162
584
        Ok(inner_req) => {
163
584
            if let Err(resp) = check_auth_params(FN_NAME, &inner_req, &state.model).await {
164
8
                return resp;
165
576
            }
166
        }
167
    }
168

            
169
576
    let user_cond = UserQueryCond {
170
576
        user_id: None,
171
576
        account: Some(req.account.as_str()),
172
576
    };
173
576
    let user_id = match state.model.user().get(&user_cond).await {
174
        Err(e) => {
175
            let err_str = e.to_string();
176
            error!("[{}] get user DB error: {}", FN_NAME, err_str.as_str());
177
            return resp_temporary_unavailable(Some(err_str));
178
        }
179
576
        Ok(user) => match user {
180
            None => {
181
4
                return resp_invalid_auth(None);
182
            }
183
572
            Some(user) => {
184
572
                if user.disabled_at.is_some() {
185
                    return resp_invalid_auth(None);
186
572
                } else if let Some(expired_at) = user.expired_at
187
8
                    && Utc::now() >= expired_at
188
                {
189
                    return resp_invalid_auth(None);
190
572
                }
191
572
                let hash = strings::password_hash(req.password.as_str(), user.salt.as_str());
192
572
                if user.password != hash {
193
4
                    return resp_invalid_auth(None);
194
568
                }
195
568
                user.user_id
196
            }
197
        },
198
    };
199

            
200
568
    let session = LoginSession {
201
568
        session_id: strings::random_id_sha(&Utc::now(), 4),
202
568
        expires_at: match TimeDelta::try_seconds(login_session::EXPIRES) {
203
            None => panic!("{}", E_UNKNOWN),
204
568
            Some(t) => Utc::now() + t,
205
        },
206
568
        user_id,
207
    };
208
568
    if let Err(e) = state.model.login_session().add(&session).await {
209
        let err_str = e.to_string();
210
        error!("[{}] add login session error: {}", FN_NAME, e);
211
        return resp_temporary_unavailable(Some(err_str));
212
568
    }
213
568
    resp_found(format!(
214
        "{}/oauth2/authorize?{}&session_id={}",
215
        state.scope_path, req.state, session.session_id
216
    ))
217
592
}
218

            
219
/// `GET /{base}/oauth2/authorize` and `POST /{base}/oauth2/authorize`
220
///
221
/// To render the OAuth2 grant page or to authorize the client and grant.
222
620
pub async fn authorize(
223
620
    State(state): State<AppState>,
224
620
    Extension(tera): Extension<Tera>,
225
620
    req: AuthorizationRequest,
226
620
) -> impl IntoResponse {
227
    const FN_NAME: &'static str = "authorize";
228

            
229
620
    let mut endpoint = Endpoint::new(state.model.clone(), None);
230
620
    let pending = match code_grant::authorization::authorization_code(&mut endpoint, &req).await {
231
40
        Err(e) => match e {
232
            AuthorizationError::Ignore => {
233
16
                return resp_invalid_request(None);
234
            }
235
24
            AuthorizationError::Redirect(url) => {
236
24
                let url: Url = url.into();
237
24
                return resp_found(url.to_string());
238
            }
239
            AuthorizationError::PrimitiveError => {
240
                error!("[{}] authorize() with primitive error", FN_NAME);
241
                return resp_temporary_unavailable(None);
242
            }
243
        },
244
580
        Ok(pending) => pending,
245
    };
246

            
247
580
    if let Some(allowed) = req.allowed() {
248
560
        match allowed {
249
            false => {
250
8
                if let Err(e) = pending.deny() {
251
8
                    match e {
252
8
                        AuthorizationError::Redirect(url) => {
253
8
                            let url: Url = url.into();
254
8
                            return resp_found(url.to_string());
255
                        }
256
                        _ => (),
257
                    }
258
                }
259
                let e = OAuth2Error::new("server_error", Some("deny error".to_string()));
260
                return (StatusCode::INTERNAL_SERVER_ERROR, Json(e)).into_response();
261
            }
262
            true => {
263
552
                let session_id = req.session_id();
264
552
                let user_id = match state.model.login_session().get(session_id).await {
265
                    Err(e) => {
266
                        error!("[{}] authorize() with primitive error: {}", FN_NAME, e);
267
                        return resp_temporary_unavailable(None);
268
                    }
269
552
                    Ok(session) => match session {
270
                        None => {
271
4
                            warn!("[{}] authorize() with invalid session ID", FN_NAME);
272
4
                            return resp_invalid_auth(None);
273
                        }
274
548
                        Some(session) => session.user_id,
275
                    },
276
                };
277
548
                let cond = SessionQueryCond {
278
548
                    session_id: Some(session_id),
279
548
                    ..Default::default()
280
548
                };
281
548
                if let Err(e) = state.model.login_session().del(&cond).await {
282
                    error!("[{}] authorize() remove session ID error: {}", FN_NAME, e);
283
                    return resp_temporary_unavailable(None);
284
548
                }
285
548
                match pending.authorize(&mut endpoint, Cow::from(user_id)).await {
286
                    Err(_) => {
287
                        error!("[{}] authorize error", FN_NAME);
288
                        return resp_temporary_unavailable(None);
289
                    }
290
548
                    Ok(url) => {
291
548
                        return resp_found(url.to_string());
292
                    }
293
                }
294
            }
295
        }
296
20
    }
297

            
298
20
    let client_id = req.client_id().unwrap();
299
20
    let client_cond = ClientQueryCond {
300
20
        user_id: None,
301
20
        client_id: Some(client_id.as_ref()),
302
20
    };
303
20
    let client_name = match state.model.client().get(&client_cond).await {
304
        Err(e) => {
305
            let err_str = e.to_string();
306
            error!("[{}] get client DB error: {}", FN_NAME, err_str.as_str());
307
            return resp_temporary_unavailable(Some(err_str));
308
        }
309
20
        Ok(client) => match client {
310
            None => {
311
                return resp_invalid_request(Some("invalid client"));
312
            }
313
20
            Some(client) => client.name,
314
        },
315
    };
316

            
317
20
    let mut context = Context::new();
318
20
    context.insert("scope_path", &state.scope_path);
319
20
    context.insert("client_name", &client_name);
320
20
    context.insert("session_id", req.session_id());
321
20
    context.insert("client_id", client_id.as_ref());
322
20
    context.insert("response_type", req.response_type().unwrap().as_ref());
323
20
    context.insert("redirect_uri", req.redirect_uri().unwrap().as_ref());
324
20
    context.insert("allow_value", request::ALLOW_VALUE);
325
20
    let scope = req.scope();
326
20
    if let Some(ref scope) = scope {
327
12
        context.insert("scope", scope);
328
12
    }
329
20
    let state = req.state();
330
20
    if let Some(ref state) = state {
331
4
        context.insert("state", state);
332
16
    }
333
20
    let page = match tera.render(TMPL_GRANT, &context) {
334
        Err(e) => {
335
            let err_str = e.to_string();
336
            error!("[{}] get client DB error: {}", FN_NAME, err_str.as_str());
337
            return resp_temporary_unavailable(Some(err_str));
338
        }
339
20
        Ok(page) => page,
340
    };
341
20
    ([(header::CONTENT_TYPE, "text/html; charset=utf-8")], page).into_response()
342
620
}
343

            
344
/// `POST /{base}/oauth2/token`
345
///
346
/// To generate an access token with the authorization code or client credentials.
347
580
pub async fn post_token(
348
580
    State(state): State<AppState>,
349
580
    req: AccessTokenRequest,
350
580
) -> impl IntoResponse {
351
580
    let mut endpoint = Endpoint::new(state.model.clone(), None);
352

            
353
580
    if let Some(grant_type) = req.grant_type() {
354
580
        if grant_type.eq("client_credentials") {
355
32
            return client_credentials_token(&req, &state, &mut endpoint).await;
356
548
        }
357
    }
358

            
359
548
    let token = match code_grant::access_token::access_token(&mut endpoint, &req).await {
360
52
        Err(e) => match e {
361
36
            AccessTokenError::Invalid(desc) => {
362
36
                return (
363
36
                    StatusCode::BAD_REQUEST,
364
36
                    [(header::CONTENT_TYPE, ContentType::JSON)],
365
36
                    desc.to_json(),
366
36
                )
367
36
                    .into_response();
368
            }
369
16
            AccessTokenError::Unauthorized(desc, authtype) => {
370
16
                return (
371
16
                    StatusCode::UNAUTHORIZED,
372
16
                    [
373
16
                        (header::CONTENT_TYPE, ContentType::JSON),
374
16
                        (header::WWW_AUTHENTICATE, authtype.as_str()),
375
16
                    ],
376
16
                    desc.to_json(),
377
16
                )
378
16
                    .into_response();
379
            }
380
            // TODO: handle this
381
            AccessTokenError::Primitive(_e) => {
382
                return StatusCode::SERVICE_UNAVAILABLE.into_response();
383
            }
384
        },
385
496
        Ok(token) => token,
386
    };
387
496
    ([(header::CONTENT_TYPE, ContentType::JSON)], token.to_json()).into_response()
388
580
}
389

            
390
/// `POST /{base}/oauth2/refresh`
391
///
392
/// To refresh an access token.
393
68
pub async fn post_refresh(
394
68
    State(state): State<AppState>,
395
68
    req: RefreshTokenRequest,
396
68
) -> impl IntoResponse {
397
68
    let mut endpoint = Endpoint::new(state.model.clone(), None);
398
68
    let token = match code_grant::refresh::refresh(&mut endpoint, &req).await {
399
44
        Err(e) => match e {
400
24
            RefreshTokenError::Invalid(desc) => {
401
24
                return (
402
24
                    StatusCode::BAD_REQUEST,
403
24
                    [(header::CONTENT_TYPE, ContentType::JSON)],
404
24
                    desc.to_json(),
405
24
                )
406
24
                    .into_response();
407
            }
408
20
            RefreshTokenError::Unauthorized(desc, authtype) => {
409
20
                return (
410
20
                    StatusCode::UNAUTHORIZED,
411
20
                    [
412
20
                        (header::CONTENT_TYPE, ContentType::JSON),
413
20
                        (header::WWW_AUTHENTICATE, authtype.as_str()),
414
20
                    ],
415
20
                    desc.to_json(),
416
20
                )
417
20
                    .into_response();
418
            }
419
            RefreshTokenError::Primitive => return StatusCode::SERVICE_UNAVAILABLE.into_response(),
420
        },
421
24
        Ok(token) => token,
422
    };
423
24
    ([(header::CONTENT_TYPE, ContentType::JSON)], token.to_json()).into_response()
424
68
}
425

            
426
32
async fn client_credentials_token(
427
32
    req: &AccessTokenRequest,
428
32
    state: &AppState,
429
32
    endpoint: &mut Endpoint,
430
32
) -> Response {
431
    // Validate the client.
432
32
    let (client_id, client_secret) = match req.authorization() {
433
4
        Authorization::None => return resp_invalid_request(None),
434
12
        Authorization::Username(_) => return resp_invalid_client(None),
435
16
        Authorization::UsernamePassword(user, pass) => (user, pass),
436
        _ => return resp_invalid_request(None),
437
    };
438
16
    let cond = ClientQueryCond {
439
16
        client_id: Some(client_id.as_ref()),
440
16
        ..Default::default()
441
16
    };
442
16
    let client = match state.model.client().get(&cond).await {
443
        Err(e) => return resp_temporary_unavailable(Some(format!("{}", e))),
444
16
        Ok(client) => match client {
445
            None => return resp_invalid_client(None),
446
16
            Some(client) => client,
447
        },
448
    };
449
16
    match client.client_secret.as_ref() {
450
        None => return resp_invalid_client(None),
451
16
        Some(secret) => match secret.as_bytes().eq(client_secret.as_ref()) {
452
4
            false => return resp_invalid_client(None),
453
12
            true => (),
454
        },
455
    }
456

            
457
    // Reuse the issuer to generate tokens.
458
4
    let grant = Grant {
459
12
        owner_id: client.user_id,
460
12
        client_id: client.client_id,
461
12
        scope: match client.scopes.as_slice().join(" ").parse() {
462
            Err(_) => return resp_invalid_client(Some("no valid scope")),
463
12
            Ok(scope) => scope,
464
        },
465
12
        redirect_uri: match client.redirect_uris.get(0) {
466
4
            None => return resp_invalid_client(Some("no valid redirect_uri")),
467
8
            Some(uri) => match Url::parse(uri.as_str()) {
468
4
                Err(_) => return resp_invalid_client(Some("invalid redirect_uri")),
469
4
                Ok(uri) => uri,
470
            },
471
        },
472
4
        until: match TimeDelta::try_minutes(10) {
473
            None => panic!("{}", E_UNKNOWN),
474
4
            Some(t) => Utc::now() + t,
475
        },
476
4
        extensions: Extensions::new(),
477
    };
478
4
    let token = match endpoint.issuer().issue(grant).await {
479
        Err(_) => return resp_temporary_unavailable(None),
480
4
        Ok(token) => token,
481
    };
482

            
483
4
    Json(TokenResponse {
484
4
        access_token: Some(token.token),
485
4
        refresh_token: token.refresh,
486
4
        token_type: Some("bearer".to_string()),
487
4
        expires_in: Some(token.until.signed_duration_since(Utc::now()).num_seconds()),
488
4
        scope: Some(client.scopes.as_slice().join(" ")),
489
4
        error: None,
490
4
    })
491
4
    .into_response()
492
32
}
493

            
494
/// To check the authorization grant flow parameters.
495
632
async fn check_auth_params(
496
632
    fn_name: &str,
497
632
    req: &GetAuthRequest,
498
632
    model: &Arc<dyn Model>,
499
632
) -> Result<(), Response> {
500
632
    if req.response_type != "code" {
501
12
        return Err(resp_invalid_request(Some("unsupport response_type")));
502
620
    }
503

            
504
620
    let client_cond = ClientQueryCond {
505
620
        user_id: None,
506
620
        client_id: Some(req.client_id.as_str()),
507
620
    };
508
620
    match model.client().get(&client_cond).await {
509
        Err(e) => {
510
            let err_str = e.to_string();
511
            error!("[{}] get client DB error: {}", fn_name, err_str.as_str());
512
            return Err(resp_temporary_unavailable(Some(err_str)));
513
        }
514
620
        Ok(client) => match client {
515
            None => {
516
4
                return Err(resp_invalid_request(Some("invalid client")));
517
            }
518
616
            Some(client) => {
519
616
                if !client.redirect_uris.contains(&req.redirect_uri) {
520
4
                    return Err(resp_invalid_request(Some("invalid redirect_uri")));
521
612
                } else if client.scopes.len() > 0 {
522
36
                    if req.scope.is_none() {
523
12
                        return Err(redirect_invalid_scope(&req.redirect_uri));
524
24
                    }
525
24
                    let req_scopes = match req.scope.as_ref().unwrap().parse::<Scope>() {
526
4
                        Err(_e) => {
527
                            // TODO: handle this with the error reason.
528
4
                            return Err(redirect_invalid_scope(&req.redirect_uri));
529
                        }
530
20
                        Ok(scopes) => scopes,
531
                    };
532
20
                    let client_scopes = match client.scopes.join(" ").parse::<Scope>() {
533
                        Err(e) => {
534
                            error!("[{}] parse client scopes error: {}", fn_name, e);
535
                            return Err(redirect_server_error(fn_name, &req.redirect_uri, None));
536
                        }
537
20
                        Ok(scopes) => scopes,
538
                    };
539
20
                    if !req_scopes.allow_access(&client_scopes) {
540
4
                        return Err(redirect_invalid_scope(&req.redirect_uri));
541
16
                    }
542
576
                }
543
            }
544
        },
545
    }
546
592
    Ok(())
547
632
}
548

            
549
20
fn redirect_invalid_scope(redirect_uri: &str) -> Response {
550
20
    resp_found(format!("{}?error=invalid_scope", redirect_uri))
551
20
}
552

            
553
fn redirect_server_error(fn_name: &str, redirect_uri: &str, description: Option<&str>) -> Response {
554
    let location = match description {
555
        None => format!("{}?error=server_error", redirect_uri),
556
        Some(desc) => {
557
            let err_desc = [("error_description", desc)];
558
            match serde_urlencoded::to_string(&err_desc) {
559
                Err(e) => {
560
                    error!("[{}] redirect server_error encode error: {}", fn_name, e);
561
                    format!("{}?error=server_error", redirect_uri)
562
                }
563
                Ok(qs) => format!("{}?error=server_error&{}", redirect_uri, qs),
564
            }
565
        }
566
    };
567
    resp_found(location)
568
}
569

            
570
1176
fn resp_found(location: String) -> Response {
571
1176
    (StatusCode::FOUND, [(header::LOCATION, location)]).into_response()
572
1176
}
573

            
574
12
fn resp_invalid_auth<'a>(description: Option<&'a str>) -> Response {
575
12
    let description = match description {
576
12
        None => None,
577
        Some(description) => Some(description.to_string()),
578
    };
579
12
    (
580
12
        StatusCode::BAD_REQUEST,
581
12
        Json(OAuth2Error::new("invalid_auth", description)),
582
12
    )
583
12
        .into_response()
584
12
}
585

            
586
24
fn resp_invalid_client<'a>(description: Option<&'a str>) -> Response {
587
24
    let description = match description {
588
16
        None => None,
589
8
        Some(description) => Some(description.to_string()),
590
    };
591
24
    (
592
24
        StatusCode::UNAUTHORIZED,
593
24
        Json(OAuth2Error::new("invalid_client", description)),
594
24
    )
595
24
        .into_response()
596
24
}
597

            
598
56
fn resp_invalid_request<'a>(description: Option<&'a str>) -> Response {
599
56
    let description = match description {
600
20
        None => None,
601
36
        Some(description) => Some(description.to_string()),
602
    };
603
56
    (
604
56
        StatusCode::BAD_REQUEST,
605
56
        Json(OAuth2Error::new("invalid_request", description)),
606
56
    )
607
56
        .into_response()
608
56
}
609

            
610
fn resp_temporary_unavailable(description: Option<String>) -> Response {
611
    (
612
        StatusCode::SERVICE_UNAVAILABLE,
613
        Json(OAuth2Error::new("temporarily_unavailable", description)),
614
    )
615
        .into_response()
616
}