1
use std::{borrow::Cow, sync::Arc};
2

            
3
use axum::{
4
    extract::State,
5
    http::{header, StatusCode},
6
    response::{IntoResponse, Response},
7
    Extension,
8
};
9
use chrono::{TimeDelta, Utc};
10
use log::{error, warn};
11
use oxide_auth::{
12
    code_grant::{
13
        accesstoken::{Authorization, Error as AccessTokenError, Request, TokenResponse},
14
        authorization::{Error as AuthorizationError, Request as OxideAuthorizationRequest},
15
        refresh::Error as RefreshTokenError,
16
    },
17
    primitives::{
18
        grant::{Extensions, Grant},
19
        scope::Scope,
20
    },
21
};
22
use oxide_auth_async::code_grant::{self, access_token::Endpoint as TokenEndpoint};
23
use serde_urlencoded;
24
use tera::{Context, Tera};
25
use url::Url;
26

            
27
use sylvia_iot_corelib::{constants::ContentType, err::E_UNKNOWN, http::Json, strings};
28

            
29
use super::{
30
    super::State as AppState,
31
    endpoint::Endpoint,
32
    request::{
33
        self, AccessTokenRequest, AuthorizationRequest, GetAuthRequest, GetLoginRequest,
34
        PostLoginRequest, RefreshTokenRequest,
35
    },
36
    response::OAuth2Error,
37
};
38
use crate::models::{
39
    client::QueryCond as ClientQueryCond,
40
    login_session::{self, LoginSession, QueryCond as SessionQueryCond},
41
    user::QueryCond as UserQueryCond,
42
    Model,
43
};
44

            
45
/// Name of the Tera template rendered for the login page.
pub const TMPL_LOGIN: &str = "login";
/// Name of the Tera template rendered for the OAuth2 grant (allow/deny) page.
pub const TMPL_GRANT: &str = "grant";
47

            
48
/// `GET /{base}/oauth2/auth`
49
///
50
/// Authenticate client and redirect to the login page.
51
32
pub async fn get_auth(State(state): State<AppState>, req: GetAuthRequest) -> impl IntoResponse {
52
    const FN_NAME: &'static str = "get_auth";
53

            
54
32
    if let Err(resp) = check_auth_params(FN_NAME, &req, &state.model).await {
55
24
        return resp;
56
8
    }
57

            
58
8
    let login_state: String = match serde_urlencoded::to_string(&req) {
59
        Err(e) => {
60
            let err_str = e.to_string();
61
            error!(
62
                "[{}] encode authorize state error: {}",
63
                FN_NAME,
64
                err_str.as_str()
65
            );
66
            return redirect_server_error(
67
                FN_NAME,
68
                req.redirect_uri.as_str(),
69
                Some(err_str.as_str()),
70
            );
71
        }
72
8
        Ok(str) => match serde_urlencoded::to_string(GetLoginRequest { state: str }) {
73
            Err(e) => {
74
                let err_str = e.to_string();
75
                error!(
76
                    "[{}] encode login state error: {}",
77
                    FN_NAME,
78
                    err_str.as_str()
79
                );
80
                return redirect_server_error(
81
                    FN_NAME,
82
                    req.redirect_uri.as_str(),
83
                    Some(err_str.as_str()),
84
                );
85
            }
86
8
            Ok(str) => str,
87
8
        },
88
8
    };
89
8
    resp_found(format!("{}/oauth2/login?{}", state.scope_path, login_state))
90
32
}
91

            
92
/// `GET /{base}/oauth2/login`
93
///
94
/// To render the login page.
95
24
pub async fn get_login(
96
24
    State(state): State<AppState>,
97
24
    Extension(tera): Extension<Tera>,
98
24
    req: GetLoginRequest,
99
24
) -> impl IntoResponse {
100
    const FN_NAME: &'static str = "get_login";
101

            
102
24
    if req.state.as_str().len() == 0 {
103
4
        warn!("[{}] empty state content", FN_NAME);
104
4
        return resp_invalid_request(Some("invalid state content"));
105
20
    }
106
20
    match serde_urlencoded::from_str::<GetAuthRequest>(req.state.as_str()) {
107
4
        Err(e) => {
108
4
            warn!(
109
                "[{}] parse state error: {}, content: {}",
110
                FN_NAME,
111
                e,
112
                req.state.as_str()
113
            );
114
4
            return resp_invalid_request(Some("invalid state content"));
115
        }
116
16
        Ok(inner_req) => {
117
16
            if let Err(resp) = check_auth_params(FN_NAME, &inner_req, &state.model).await {
118
8
                return resp;
119
8
            }
120
8
        }
121
8
    }
122
8

            
123
8
    let mut context = Context::new();
124
8
    context.insert("scope_path", &state.scope_path);
125
8
    context.insert("state", &req.state);
126
8
    let page = match tera.render(TMPL_LOGIN, &context) {
127
        Err(e) => {
128
            let err_str = e.to_string();
129
            error!(
130
                "[{}] render login template error: {}",
131
                FN_NAME,
132
                err_str.as_str()
133
            );
134
            return resp_temporary_unavailable(Some(err_str));
135
        }
136
8
        Ok(page) => page,
137
8
    };
138
8

            
139
8
    ([(header::CONTENT_TYPE, "text/html; charset=utf-8")], page).into_response()
140
24
}
141

            
142
/// `POST /{base}/oauth2/login`
143
///
144
/// Do the login process.
145
592
pub async fn post_login(State(state): State<AppState>, req: PostLoginRequest) -> impl IntoResponse {
146
    const FN_NAME: &'static str = "post_login";
147

            
148
592
    if req.state.as_str().len() == 0 {
149
4
        warn!("[{}] empty state content", FN_NAME);
150
4
        return resp_invalid_request(Some("invalid state content"));
151
588
    }
152
588
    match serde_urlencoded::from_str::<GetAuthRequest>(req.state.as_str()) {
153
4
        Err(e) => {
154
4
            warn!(
155
                "[{}] parse state error: {}, content: {}",
156
                FN_NAME,
157
                e,
158
                req.state.as_str()
159
            );
160
4
            return resp_invalid_request(Some("invalid state content"));
161
        }
162
584
        Ok(inner_req) => {
163
584
            if let Err(resp) = check_auth_params(FN_NAME, &inner_req, &state.model).await {
164
8
                return resp;
165
576
            }
166
576
        }
167
576
    }
168
576

            
169
576
    let user_cond = UserQueryCond {
170
576
        user_id: None,
171
576
        account: Some(req.account.as_str()),
172
576
    };
173
576
    let user_id = match state.model.user().get(&user_cond).await {
174
        Err(e) => {
175
            let err_str = e.to_string();
176
            error!("[{}] get user DB error: {}", FN_NAME, err_str.as_str());
177
            return resp_temporary_unavailable(Some(err_str));
178
        }
179
576
        Ok(user) => match user {
180
            None => {
181
4
                return resp_invalid_auth(None);
182
            }
183
572
            Some(user) => {
184
572
                let hash = strings::password_hash(req.password.as_str(), user.salt.as_str());
185
572
                if user.password != hash {
186
4
                    return resp_invalid_auth(None);
187
568
                }
188
568
                user.user_id
189
            }
190
        },
191
    };
192

            
193
568
    let session = LoginSession {
194
568
        session_id: strings::random_id_sha(&Utc::now(), 4),
195
568
        expires_at: match TimeDelta::try_seconds(login_session::EXPIRES) {
196
            None => panic!("{}", E_UNKNOWN),
197
568
            Some(t) => Utc::now() + t,
198
568
        },
199
568
        user_id,
200
    };
201
568
    if let Err(e) = state.model.login_session().add(&session).await {
202
        let err_str = e.to_string();
203
        error!("[{}] add login session error: {}", FN_NAME, e);
204
        return resp_temporary_unavailable(Some(err_str));
205
568
    }
206
568
    resp_found(format!(
207
568
        "{}/oauth2/authorize?{}&session_id={}",
208
568
        state.scope_path, req.state, session.session_id
209
568
    ))
210
592
}
211

            
212
/// `GET /{base}/oauth2/authorize` and `POST /{base}/oauth2/authorize`
213
///
214
/// To render the OAuth2 grant page or to authorize the client and grant.
215
620
pub async fn authorize(
216
620
    State(state): State<AppState>,
217
620
    Extension(tera): Extension<Tera>,
218
620
    req: AuthorizationRequest,
219
620
) -> impl IntoResponse {
220
    const FN_NAME: &'static str = "authorize";
221

            
222
620
    let mut endpoint = Endpoint::new(state.model.clone(), None);
223
620
    let pending = match code_grant::authorization::authorization_code(&mut endpoint, &req).await {
224
40
        Err(e) => match e {
225
            AuthorizationError::Ignore => {
226
16
                return resp_invalid_request(None);
227
            }
228
24
            AuthorizationError::Redirect(url) => {
229
24
                let url: Url = url.into();
230
24
                return resp_found(url.to_string());
231
            }
232
            AuthorizationError::PrimitiveError => {
233
                error!("[{}] authorize() with primitive error", FN_NAME);
234
                return resp_temporary_unavailable(None);
235
            }
236
        },
237
580
        Ok(pending) => pending,
238
    };
239

            
240
580
    if let Some(allowed) = req.allowed() {
241
560
        match allowed {
242
            false => {
243
8
                if let Err(e) = pending.deny() {
244
8
                    match e {
245
8
                        AuthorizationError::Redirect(url) => {
246
8
                            let url: Url = url.into();
247
8
                            return resp_found(url.to_string());
248
                        }
249
                        _ => (),
250
                    }
251
                }
252
                let e = OAuth2Error::new("server_error", Some("deny error".to_string()));
253
                return (StatusCode::INTERNAL_SERVER_ERROR, Json(e)).into_response();
254
            }
255
            true => {
256
552
                let session_id = req.session_id();
257
552
                let user_id = match state.model.login_session().get(session_id).await {
258
                    Err(e) => {
259
                        error!("[{}] authorize() with primitive error: {}", FN_NAME, e);
260
                        return resp_temporary_unavailable(None);
261
                    }
262
552
                    Ok(session) => match session {
263
                        None => {
264
4
                            warn!("[{}] authorize() with invalid session ID", FN_NAME);
265
4
                            return resp_invalid_auth(None);
266
                        }
267
548
                        Some(session) => session.user_id,
268
548
                    },
269
548
                };
270
548
                let cond = SessionQueryCond {
271
548
                    session_id: Some(session_id),
272
548
                    ..Default::default()
273
548
                };
274
548
                if let Err(e) = state.model.login_session().del(&cond).await {
275
                    error!("[{}] authorize() remove session ID error: {}", FN_NAME, e);
276
                    return resp_temporary_unavailable(None);
277
548
                }
278
548
                match pending.authorize(&mut endpoint, Cow::from(user_id)).await {
279
                    Err(_) => {
280
                        error!("[{}] authorize error", FN_NAME);
281
                        return resp_temporary_unavailable(None);
282
                    }
283
548
                    Ok(url) => {
284
548
                        return resp_found(url.to_string());
285
                    }
286
                }
287
            }
288
        }
289
20
    }
290
20

            
291
20
    let client_id = req.client_id().unwrap();
292
20
    let client_cond = ClientQueryCond {
293
20
        user_id: None,
294
20
        client_id: Some(client_id.as_ref()),
295
20
    };
296
20
    let client_name = match state.model.client().get(&client_cond).await {
297
        Err(e) => {
298
            let err_str = e.to_string();
299
            error!("[{}] get client DB error: {}", FN_NAME, err_str.as_str());
300
            return resp_temporary_unavailable(Some(err_str));
301
        }
302
20
        Ok(client) => match client {
303
            None => {
304
                return resp_invalid_request(Some("invalid client"));
305
            }
306
20
            Some(client) => client.name,
307
20
        },
308
20
    };
309
20

            
310
20
    let mut context = Context::new();
311
20
    context.insert("scope_path", &state.scope_path);
312
20
    context.insert("client_name", &client_name);
313
20
    context.insert("session_id", req.session_id());
314
20
    context.insert("client_id", client_id.as_ref());
315
20
    context.insert("response_type", req.response_type().unwrap().as_ref());
316
20
    context.insert("redirect_uri", req.redirect_uri().unwrap().as_ref());
317
20
    context.insert("allow_value", request::ALLOW_VALUE);
318
20
    let scope = req.scope();
319
20
    if let Some(ref scope) = scope {
320
12
        context.insert("scope", scope);
321
12
    }
322
20
    let state = req.state();
323
20
    if let Some(ref state) = state {
324
4
        context.insert("state", state);
325
16
    }
326
20
    let page = match tera.render(TMPL_GRANT, &context) {
327
        Err(e) => {
328
            let err_str = e.to_string();
329
            error!("[{}] get client DB error: {}", FN_NAME, err_str.as_str());
330
            return resp_temporary_unavailable(Some(err_str));
331
        }
332
20
        Ok(page) => page,
333
20
    };
334
20
    ([(header::CONTENT_TYPE, "text/html; charset=utf-8")], page).into_response()
335
620
}
336

            
337
/// `POST /{base}/oauth2/token`
338
///
339
/// To generate an access token with the authorization code or client credentials.
340
580
pub async fn post_token(
341
580
    State(state): State<AppState>,
342
580
    req: AccessTokenRequest,
343
580
) -> impl IntoResponse {
344
580
    let mut endpoint = Endpoint::new(state.model.clone(), None);
345

            
346
580
    if let Some(grant_type) = req.grant_type() {
347
580
        if grant_type.eq("client_credentials") {
348
32
            return client_credentials_token(&req, &state, &mut endpoint).await;
349
548
        }
350
    }
351

            
352
548
    let token = match code_grant::access_token::access_token(&mut endpoint, &req).await {
353
52
        Err(e) => match e {
354
36
            AccessTokenError::Invalid(desc) => {
355
36
                return (
356
36
                    StatusCode::BAD_REQUEST,
357
36
                    [(header::CONTENT_TYPE, ContentType::JSON)],
358
36
                    desc.to_json(),
359
36
                )
360
36
                    .into_response();
361
            }
362
16
            AccessTokenError::Unauthorized(desc, authtype) => {
363
16
                return (
364
16
                    StatusCode::UNAUTHORIZED,
365
16
                    [
366
16
                        (header::CONTENT_TYPE, ContentType::JSON),
367
16
                        (header::WWW_AUTHENTICATE, authtype.as_str()),
368
16
                    ],
369
16
                    desc.to_json(),
370
16
                )
371
16
                    .into_response();
372
            }
373
            // TODO: handle this
374
            AccessTokenError::Primitive(_e) => {
375
                return StatusCode::SERVICE_UNAVAILABLE.into_response()
376
            }
377
        },
378
496
        Ok(token) => token,
379
496
    };
380
496
    ([(header::CONTENT_TYPE, ContentType::JSON)], token.to_json()).into_response()
381
580
}
382

            
383
/// `POST /{base}/oauth2/refresh`
384
///
385
/// To refresh an access token.
386
68
pub async fn post_refresh(
387
68
    State(state): State<AppState>,
388
68
    req: RefreshTokenRequest,
389
68
) -> impl IntoResponse {
390
68
    let mut endpoint = Endpoint::new(state.model.clone(), None);
391
68
    let token = match code_grant::refresh::refresh(&mut endpoint, &req).await {
392
44
        Err(e) => match e {
393
24
            RefreshTokenError::Invalid(desc) => {
394
24
                return (
395
24
                    StatusCode::BAD_REQUEST,
396
24
                    [(header::CONTENT_TYPE, ContentType::JSON)],
397
24
                    desc.to_json(),
398
24
                )
399
24
                    .into_response();
400
            }
401
20
            RefreshTokenError::Unauthorized(desc, authtype) => {
402
20
                return (
403
20
                    StatusCode::UNAUTHORIZED,
404
20
                    [
405
20
                        (header::CONTENT_TYPE, ContentType::JSON),
406
20
                        (header::WWW_AUTHENTICATE, authtype.as_str()),
407
20
                    ],
408
20
                    desc.to_json(),
409
20
                )
410
20
                    .into_response();
411
            }
412
            RefreshTokenError::Primitive => return StatusCode::SERVICE_UNAVAILABLE.into_response(),
413
        },
414
24
        Ok(token) => token,
415
24
    };
416
24
    ([(header::CONTENT_TYPE, ContentType::JSON)], token.to_json()).into_response()
417
68
}
418

            
419
32
async fn client_credentials_token(
420
32
    req: &AccessTokenRequest,
421
32
    state: &AppState,
422
32
    endpoint: &mut Endpoint,
423
32
) -> Response {
424
    // Validate the client.
425
32
    let (client_id, client_secret) = match req.authorization() {
426
4
        Authorization::None => return resp_invalid_request(None),
427
12
        Authorization::Username(_) => return resp_invalid_client(None),
428
16
        Authorization::UsernamePassword(user, pass) => (user, pass),
429
        _ => return resp_invalid_request(None),
430
    };
431
16
    let cond = ClientQueryCond {
432
16
        client_id: Some(client_id.as_ref()),
433
16
        ..Default::default()
434
16
    };
435
16
    let client = match state.model.client().get(&cond).await {
436
        Err(e) => return resp_temporary_unavailable(Some(format!("{}", e))),
437
16
        Ok(client) => match client {
438
            None => return resp_invalid_client(None),
439
16
            Some(client) => client,
440
16
        },
441
16
    };
442
16
    match client.client_secret.as_ref() {
443
        None => return resp_invalid_client(None),
444
16
        Some(secret) => match secret.as_bytes().eq(client_secret.as_ref()) {
445
4
            false => return resp_invalid_client(None),
446
12
            true => (),
447
        },
448
    }
449

            
450
    // Reuse the issuer to generate tokens.
451
4
    let grant = Grant {
452
12
        owner_id: client.user_id,
453
12
        client_id: client.client_id,
454
12
        scope: match client.scopes.as_slice().join(" ").parse() {
455
            Err(_) => return resp_invalid_client(Some("no valid scope")),
456
12
            Ok(scope) => scope,
457
12
        },
458
12
        redirect_uri: match client.redirect_uris.get(0) {
459
4
            None => return resp_invalid_client(Some("no valid redirect_uri")),
460
8
            Some(uri) => match Url::parse(uri.as_str()) {
461
4
                Err(_) => return resp_invalid_client(Some("invalid redirect_uri")),
462
4
                Ok(uri) => uri,
463
4
            },
464
4
        },
465
4
        until: match TimeDelta::try_minutes(10) {
466
            None => panic!("{}", E_UNKNOWN),
467
4
            Some(t) => Utc::now() + t,
468
4
        },
469
4
        extensions: Extensions::new(),
470
    };
471
4
    let token = match endpoint.issuer().issue(grant).await {
472
        Err(_) => return resp_temporary_unavailable(None),
473
4
        Ok(token) => token,
474
4
    };
475
4

            
476
4
    Json(TokenResponse {
477
4
        access_token: Some(token.token),
478
4
        refresh_token: token.refresh,
479
4
        token_type: Some("bearer".to_string()),
480
4
        expires_in: Some(token.until.signed_duration_since(Utc::now()).num_seconds()),
481
4
        scope: Some(client.scopes.as_slice().join(" ")),
482
4
        error: None,
483
4
    })
484
4
    .into_response()
485
32
}
486

            
487
/// To check the authorization grant flow parameters.
488
632
async fn check_auth_params(
489
632
    fn_name: &str,
490
632
    req: &GetAuthRequest,
491
632
    model: &Arc<dyn Model>,
492
632
) -> Result<(), Response> {
493
632
    if req.response_type != "code" {
494
12
        return Err(resp_invalid_request(Some("unsupport response_type")));
495
620
    }
496
620

            
497
620
    let client_cond = ClientQueryCond {
498
620
        user_id: None,
499
620
        client_id: Some(req.client_id.as_str()),
500
620
    };
501
620
    match model.client().get(&client_cond).await {
502
        Err(e) => {
503
            let err_str = e.to_string();
504
            error!("[{}] get client DB error: {}", fn_name, err_str.as_str());
505
            return Err(resp_temporary_unavailable(Some(err_str)));
506
        }
507
620
        Ok(client) => match client {
508
            None => {
509
4
                return Err(resp_invalid_request(Some("invalid client")));
510
            }
511
616
            Some(client) => {
512
616
                if !client.redirect_uris.contains(&req.redirect_uri) {
513
4
                    return Err(resp_invalid_request(Some("invalid redirect_uri")));
514
612
                } else if client.scopes.len() > 0 {
515
36
                    if req.scope.is_none() {
516
12
                        return Err(redirect_invalid_scope(&req.redirect_uri));
517
24
                    }
518
24
                    let req_scopes = match req.scope.as_ref().unwrap().parse::<Scope>() {
519
4
                        Err(_e) => {
520
4
                            // TODO: handle this with the error reason.
521
4
                            return Err(redirect_invalid_scope(&req.redirect_uri));
522
                        }
523
20
                        Ok(scopes) => scopes,
524
                    };
525
20
                    let client_scopes = match client.scopes.join(" ").parse::<Scope>() {
526
                        Err(e) => {
527
                            error!("[{}] parse client scopes error: {}", fn_name, e);
528
                            return Err(redirect_server_error(fn_name, &req.redirect_uri, None));
529
                        }
530
20
                        Ok(scopes) => scopes,
531
20
                    };
532
20
                    if !req_scopes.allow_access(&client_scopes) {
533
4
                        return Err(redirect_invalid_scope(&req.redirect_uri));
534
16
                    }
535
576
                }
536
            }
537
        },
538
    }
539
592
    Ok(())
540
632
}
541

            
542
20
fn redirect_invalid_scope(redirect_uri: &str) -> Response {
543
20
    resp_found(format!("{}?error=invalid_scope", redirect_uri))
544
20
}
545

            
546
fn redirect_server_error(fn_name: &str, redirect_uri: &str, description: Option<&str>) -> Response {
547
    let location = match description {
548
        None => format!("{}?error=server_error", redirect_uri),
549
        Some(desc) => {
550
            let err_desc = [("error_description", desc)];
551
            match serde_urlencoded::to_string(&err_desc) {
552
                Err(e) => {
553
                    error!("[{}] redirect server_error encode error: {}", fn_name, e);
554
                    format!("{}?error=server_error", redirect_uri)
555
                }
556
                Ok(qs) => format!("{}?error=server_error&{}", redirect_uri, qs),
557
            }
558
        }
559
    };
560
    resp_found(location)
561
}
562

            
563
1176
fn resp_found(location: String) -> Response {
564
1176
    (StatusCode::FOUND, [(header::LOCATION, location)]).into_response()
565
1176
}
566

            
567
12
fn resp_invalid_auth<'a>(description: Option<&'a str>) -> Response {
568
12
    let description = match description {
569
12
        None => None,
570
        Some(description) => Some(description.to_string()),
571
    };
572
12
    (
573
12
        StatusCode::BAD_REQUEST,
574
12
        Json(OAuth2Error::new("invalid_auth", description)),
575
12
    )
576
12
        .into_response()
577
12
}
578

            
579
24
fn resp_invalid_client<'a>(description: Option<&'a str>) -> Response {
580
24
    let description = match description {
581
16
        None => None,
582
8
        Some(description) => Some(description.to_string()),
583
    };
584
24
    (
585
24
        StatusCode::UNAUTHORIZED,
586
24
        Json(OAuth2Error::new("invalid_client", description)),
587
24
    )
588
24
        .into_response()
589
24
}
590

            
591
56
fn resp_invalid_request<'a>(description: Option<&'a str>) -> Response {
592
56
    let description = match description {
593
20
        None => None,
594
36
        Some(description) => Some(description.to_string()),
595
    };
596
56
    (
597
56
        StatusCode::BAD_REQUEST,
598
56
        Json(OAuth2Error::new("invalid_request", description)),
599
56
    )
600
56
        .into_response()
601
56
}
602

            
603
fn resp_temporary_unavailable(description: Option<String>) -> Response {
604
    (
605
        StatusCode::SERVICE_UNAVAILABLE,
606
        Json(OAuth2Error::new("temporarily_unavailable", description)),
607
    )
608
        .into_response()
609
}