1
use std::{borrow::Cow, sync::Arc};
2

            
3
use axum::{
4
    extract::State,
5
    http::{header, StatusCode},
6
    response::{IntoResponse, Response},
7
    Extension,
8
};
9
use chrono::{TimeDelta, Utc};
10
use log::{error, warn};
11
use oxide_auth::{
12
    code_grant::{
13
        accesstoken::{Authorization, Error as AccessTokenError, Request, TokenResponse},
14
        authorization::{Error as AuthorizationError, Request as OxideAuthorizationRequest},
15
        refresh::Error as RefreshTokenError,
16
    },
17
    primitives::{
18
        grant::{Extensions, Grant},
19
        scope::Scope,
20
    },
21
};
22
use oxide_auth_async::code_grant::{self, access_token::Endpoint as TokenEndpoint};
23
use serde_urlencoded;
24
use tera::{Context, Tera};
25
use url::Url;
26

            
27
use sylvia_iot_corelib::{constants::ContentType, err::E_UNKNOWN, http::Json, strings};
28

            
29
use super::{
30
    super::State as AppState,
31
    endpoint::Endpoint,
32
    request::{
33
        self, AccessTokenRequest, AuthorizationRequest, GetAuthRequest, GetLoginRequest,
34
        PostLoginRequest, RefreshTokenRequest,
35
    },
36
    response::OAuth2Error,
37
};
38
use crate::models::{
39
    client::QueryCond as ClientQueryCond,
40
    login_session::{self, LoginSession, QueryCond as SessionQueryCond},
41
    user::QueryCond as UserQueryCond,
42
    Model,
43
};
44

            
45
/// Name of the Tera template rendered by the login page handler.
pub const TMPL_LOGIN: &str = "login";
/// Name of the Tera template rendered by the grant page handler.
pub const TMPL_GRANT: &str = "grant";
47

            
48
/// `GET /{base}/oauth2/auth`
49
///
50
/// Authenticate client and redirect to the login page.
51
16
pub async fn get_auth(State(state): State<AppState>, req: GetAuthRequest) -> impl IntoResponse {
52
    const FN_NAME: &'static str = "get_auth";
53

            
54
16
    if let Err(resp) = check_auth_params(FN_NAME, &req, &state.model).await {
55
12
        return resp;
56
4
    }
57

            
58
4
    let login_state: String = match serde_urlencoded::to_string(&req) {
59
        Err(e) => {
60
            let err_str = e.to_string();
61
            error!(
62
                "[{}] encode authorize state error: {}",
63
                FN_NAME,
64
                err_str.as_str()
65
            );
66
            return redirect_server_error(
67
                FN_NAME,
68
                req.redirect_uri.as_str(),
69
                Some(err_str.as_str()),
70
            );
71
        }
72
4
        Ok(str) => match serde_urlencoded::to_string(GetLoginRequest { state: str }) {
73
            Err(e) => {
74
                let err_str = e.to_string();
75
                error!(
76
                    "[{}] encode login state error: {}",
77
                    FN_NAME,
78
                    err_str.as_str()
79
                );
80
                return redirect_server_error(
81
                    FN_NAME,
82
                    req.redirect_uri.as_str(),
83
                    Some(err_str.as_str()),
84
                );
85
            }
86
4
            Ok(str) => str,
87
4
        },
88
4
    };
89
4
    resp_found(format!("{}/oauth2/login?{}", state.scope_path, login_state))
90
16
}
91

            
92
/// `GET /{base}/oauth2/login`
93
///
94
/// To render the login page.
95
12
pub async fn get_login(
96
12
    State(state): State<AppState>,
97
12
    Extension(tera): Extension<Tera>,
98
12
    req: GetLoginRequest,
99
12
) -> impl IntoResponse {
100
    const FN_NAME: &'static str = "get_login";
101

            
102
12
    if req.state.as_str().len() == 0 {
103
2
        warn!("[{}] empty state content", FN_NAME);
104
2
        return resp_invalid_request(Some("invalid state content"));
105
10
    }
106
10
    match serde_urlencoded::from_str::<GetAuthRequest>(req.state.as_str()) {
107
2
        Err(e) => {
108
2
            warn!(
109
                "[{}] parse state error: {}, content: {}",
110
                FN_NAME,
111
                e,
112
                req.state.as_str()
113
            );
114
2
            return resp_invalid_request(Some("invalid state content"));
115
        }
116
8
        Ok(inner_req) => {
117
8
            if let Err(resp) = check_auth_params(FN_NAME, &inner_req, &state.model).await {
118
4
                return resp;
119
4
            }
120
4
        }
121
4
    }
122
4

            
123
4
    let mut context = Context::new();
124
4
    context.insert("scope_path", &state.scope_path);
125
4
    context.insert("state", &req.state);
126
4
    let page = match tera.render(TMPL_LOGIN, &context) {
127
        Err(e) => {
128
            let err_str = e.to_string();
129
            error!(
130
                "[{}] render login template error: {}",
131
                FN_NAME,
132
                err_str.as_str()
133
            );
134
            return resp_temporary_unavailable(Some(err_str));
135
        }
136
4
        Ok(page) => page,
137
4
    };
138
4

            
139
4
    ([(header::CONTENT_TYPE, "text/html; charset=utf-8")], page).into_response()
140
12
}
141

            
142
/// `POST /{base}/oauth2/login`
143
///
144
/// Do the login process.
145
296
pub async fn post_login(State(state): State<AppState>, req: PostLoginRequest) -> impl IntoResponse {
146
    const FN_NAME: &'static str = "post_login";
147

            
148
296
    if req.state.as_str().len() == 0 {
149
2
        warn!("[{}] empty state content", FN_NAME);
150
2
        return resp_invalid_request(Some("invalid state content"));
151
294
    }
152
294
    match serde_urlencoded::from_str::<GetAuthRequest>(req.state.as_str()) {
153
2
        Err(e) => {
154
2
            warn!(
155
                "[{}] parse state error: {}, content: {}",
156
                FN_NAME,
157
                e,
158
                req.state.as_str()
159
            );
160
2
            return resp_invalid_request(Some("invalid state content"));
161
        }
162
292
        Ok(inner_req) => {
163
292
            if let Err(resp) = check_auth_params(FN_NAME, &inner_req, &state.model).await {
164
4
                return resp;
165
288
            }
166
288
        }
167
288
    }
168
288

            
169
288
    let user_cond = UserQueryCond {
170
288
        user_id: None,
171
288
        account: Some(req.account.as_str()),
172
288
    };
173
288
    let user_id = match state.model.user().get(&user_cond).await {
174
        Err(e) => {
175
            let err_str = e.to_string();
176
            error!("[{}] get user DB error: {}", FN_NAME, err_str.as_str());
177
            return resp_temporary_unavailable(Some(err_str));
178
        }
179
288
        Ok(user) => match user {
180
            None => {
181
2
                return resp_invalid_auth(None);
182
            }
183
286
            Some(user) => {
184
286
                let hash = strings::password_hash(req.password.as_str(), user.salt.as_str());
185
286
                if user.password != hash {
186
2
                    return resp_invalid_auth(None);
187
284
                }
188
284
                user.user_id
189
            }
190
        },
191
    };
192

            
193
284
    let session = LoginSession {
194
284
        session_id: strings::random_id_sha(&Utc::now(), 4),
195
284
        expires_at: match TimeDelta::try_seconds(login_session::EXPIRES) {
196
            None => panic!("{}", E_UNKNOWN),
197
284
            Some(t) => Utc::now() + t,
198
284
        },
199
284
        user_id,
200
    };
201
284
    if let Err(e) = state.model.login_session().add(&session).await {
202
        let err_str = e.to_string();
203
        error!("[{}] add login session error: {}", FN_NAME, e);
204
        return resp_temporary_unavailable(Some(err_str));
205
284
    }
206
284
    resp_found(format!(
207
284
        "{}/oauth2/authorize?{}&session_id={}",
208
284
        state.scope_path, req.state, session.session_id
209
284
    ))
210
296
}
211

            
212
/// `GET /{base}/oauth2/authorize` and `POST /{base}/oauth2/authorize`
213
///
214
/// To render the OAuth2 grant page or to authorize the client and grant.
215
310
pub async fn authorize(
216
310
    State(state): State<AppState>,
217
310
    Extension(tera): Extension<Tera>,
218
310
    req: AuthorizationRequest,
219
310
) -> impl IntoResponse {
220
    const FN_NAME: &'static str = "authorize";
221

            
222
310
    let mut endpoint = Endpoint::new(state.model.clone(), None);
223
310
    let pending = match code_grant::authorization::authorization_code(&mut endpoint, &req).await {
224
20
        Err(e) => match e {
225
            AuthorizationError::Ignore => {
226
8
                return resp_invalid_request(None);
227
            }
228
12
            AuthorizationError::Redirect(url) => {
229
12
                let url: Url = url.into();
230
12
                return resp_found(url.to_string());
231
            }
232
            AuthorizationError::PrimitiveError => {
233
                error!("[{}] authorize() with primitive error", FN_NAME);
234
                return resp_temporary_unavailable(None);
235
            }
236
        },
237
290
        Ok(pending) => pending,
238
    };
239

            
240
290
    if let Some(allowed) = req.allowed() {
241
280
        match allowed {
242
            false => {
243
4
                if let Err(e) = pending.deny() {
244
4
                    match e {
245
4
                        AuthorizationError::Redirect(url) => {
246
4
                            let url: Url = url.into();
247
4
                            return resp_found(url.to_string());
248
                        }
249
                        _ => (),
250
                    }
251
                }
252
                let e = OAuth2Error::new("server_error", Some("deny error".to_string()));
253
                return (StatusCode::INTERNAL_SERVER_ERROR, Json(e)).into_response();
254
            }
255
            true => {
256
276
                let session_id = req.session_id();
257
276
                let user_id = match state.model.login_session().get(session_id).await {
258
                    Err(e) => {
259
                        error!("[{}] authorize() with primitive error: {}", FN_NAME, e);
260
                        return resp_temporary_unavailable(None);
261
                    }
262
276
                    Ok(session) => match session {
263
                        None => {
264
2
                            warn!("[{}] authorize() with invalid session ID", FN_NAME);
265
2
                            return resp_invalid_auth(None);
266
                        }
267
274
                        Some(session) => session.user_id,
268
274
                    },
269
274
                };
270
274
                let cond = SessionQueryCond {
271
274
                    session_id: Some(session_id),
272
274
                    ..Default::default()
273
274
                };
274
274
                if let Err(e) = state.model.login_session().del(&cond).await {
275
                    error!("[{}] authorize() remove session ID error: {}", FN_NAME, e);
276
                    return resp_temporary_unavailable(None);
277
274
                }
278
274
                match pending.authorize(&mut endpoint, Cow::from(user_id)).await {
279
                    Err(_) => {
280
                        error!("[{}] authorize error", FN_NAME);
281
                        return resp_temporary_unavailable(None);
282
                    }
283
274
                    Ok(url) => {
284
274
                        return resp_found(url.to_string());
285
                    }
286
                }
287
            }
288
        }
289
10
    }
290
10

            
291
10
    let client_id = req.client_id().unwrap();
292
10
    let client_cond = ClientQueryCond {
293
10
        user_id: None,
294
10
        client_id: Some(client_id.as_ref()),
295
10
    };
296
10
    let client_name = match state.model.client().get(&client_cond).await {
297
        Err(e) => {
298
            let err_str = e.to_string();
299
            error!("[{}] get client DB error: {}", FN_NAME, err_str.as_str());
300
            return resp_temporary_unavailable(Some(err_str));
301
        }
302
10
        Ok(client) => match client {
303
            None => {
304
                return resp_invalid_request(Some("invalid client"));
305
            }
306
10
            Some(client) => client.name,
307
10
        },
308
10
    };
309
10

            
310
10
    let mut context = Context::new();
311
10
    context.insert("scope_path", &state.scope_path);
312
10
    context.insert("client_name", &client_name);
313
10
    context.insert("session_id", req.session_id());
314
10
    context.insert("client_id", client_id.as_ref());
315
10
    context.insert("response_type", req.response_type().unwrap().as_ref());
316
10
    context.insert("redirect_uri", req.redirect_uri().unwrap().as_ref());
317
10
    context.insert("allow_value", request::ALLOW_VALUE);
318
10
    let scope = req.scope();
319
10
    if let Some(ref scope) = scope {
320
6
        context.insert("scope", scope);
321
6
    }
322
10
    let state = req.state();
323
10
    if let Some(ref state) = state {
324
2
        context.insert("state", state);
325
8
    }
326
10
    let page = match tera.render(TMPL_GRANT, &context) {
327
        Err(e) => {
328
            let err_str = e.to_string();
329
            error!("[{}] get client DB error: {}", FN_NAME, err_str.as_str());
330
            return resp_temporary_unavailable(Some(err_str));
331
        }
332
10
        Ok(page) => page,
333
10
    };
334
10
    ([(header::CONTENT_TYPE, "text/html; charset=utf-8")], page).into_response()
335
310
}
336

            
337
/// `POST /{base}/oauth2/token`
338
///
339
/// To generate an access token with the authorization code or client credentials.
340
290
pub async fn post_token(
341
290
    State(state): State<AppState>,
342
290
    req: AccessTokenRequest,
343
290
) -> impl IntoResponse {
344
290
    let mut endpoint = Endpoint::new(state.model.clone(), None);
345

            
346
290
    if let Some(grant_type) = req.grant_type() {
347
290
        if grant_type.eq("client_credentials") {
348
16
            return client_credentials_token(&req, &state, &mut endpoint).await;
349
274
        }
350
    }
351

            
352
274
    let token = match code_grant::access_token::access_token(&mut endpoint, &req).await {
353
26
        Err(e) => match e {
354
18
            AccessTokenError::Invalid(desc) => {
355
18
                return (
356
18
                    StatusCode::BAD_REQUEST,
357
18
                    [(header::CONTENT_TYPE, ContentType::JSON)],
358
18
                    desc.to_json(),
359
18
                )
360
18
                    .into_response();
361
            }
362
8
            AccessTokenError::Unauthorized(desc, authtype) => {
363
8
                return (
364
8
                    StatusCode::UNAUTHORIZED,
365
8
                    [
366
8
                        (header::CONTENT_TYPE, ContentType::JSON),
367
8
                        (header::WWW_AUTHENTICATE, authtype.as_str()),
368
8
                    ],
369
8
                    desc.to_json(),
370
8
                )
371
8
                    .into_response();
372
            }
373
            // TODO: handle this
374
            AccessTokenError::Primitive(_e) => {
375
                return StatusCode::SERVICE_UNAVAILABLE.into_response()
376
            }
377
        },
378
248
        Ok(token) => token,
379
248
    };
380
248
    ([(header::CONTENT_TYPE, ContentType::JSON)], token.to_json()).into_response()
381
290
}
382

            
383
/// `POST /{base}/oauth2/refresh`
384
///
385
/// To refresh an access token.
386
34
pub async fn post_refresh(
387
34
    State(state): State<AppState>,
388
34
    req: RefreshTokenRequest,
389
34
) -> impl IntoResponse {
390
34
    let mut endpoint = Endpoint::new(state.model.clone(), None);
391
34
    let token = match code_grant::refresh::refresh(&mut endpoint, &req).await {
392
22
        Err(e) => match e {
393
12
            RefreshTokenError::Invalid(desc) => {
394
12
                return (
395
12
                    StatusCode::BAD_REQUEST,
396
12
                    [(header::CONTENT_TYPE, ContentType::JSON)],
397
12
                    desc.to_json(),
398
12
                )
399
12
                    .into_response();
400
            }
401
10
            RefreshTokenError::Unauthorized(desc, authtype) => {
402
10
                return (
403
10
                    StatusCode::UNAUTHORIZED,
404
10
                    [
405
10
                        (header::CONTENT_TYPE, ContentType::JSON),
406
10
                        (header::WWW_AUTHENTICATE, authtype.as_str()),
407
10
                    ],
408
10
                    desc.to_json(),
409
10
                )
410
10
                    .into_response();
411
            }
412
            RefreshTokenError::Primitive => return StatusCode::SERVICE_UNAVAILABLE.into_response(),
413
        },
414
12
        Ok(token) => token,
415
12
    };
416
12
    ([(header::CONTENT_TYPE, ContentType::JSON)], token.to_json()).into_response()
417
34
}
418

            
419
16
async fn client_credentials_token(
420
16
    req: &AccessTokenRequest,
421
16
    state: &AppState,
422
16
    endpoint: &mut Endpoint,
423
16
) -> Response {
424
    // Validate the client.
425
16
    let (client_id, client_secret) = match req.authorization() {
426
2
        Authorization::None => return resp_invalid_request(None),
427
6
        Authorization::Username(_) => return resp_invalid_client(None),
428
8
        Authorization::UsernamePassword(user, pass) => (user, pass),
429
        _ => return resp_invalid_request(None),
430
    };
431
8
    let cond = ClientQueryCond {
432
8
        client_id: Some(client_id.as_ref()),
433
8
        ..Default::default()
434
8
    };
435
8
    let client = match state.model.client().get(&cond).await {
436
        Err(e) => return resp_temporary_unavailable(Some(format!("{}", e))),
437
8
        Ok(client) => match client {
438
            None => return resp_invalid_client(None),
439
8
            Some(client) => client,
440
8
        },
441
8
    };
442
8
    match client.client_secret.as_ref() {
443
        None => return resp_invalid_client(None),
444
8
        Some(secret) => match secret.as_bytes().eq(client_secret.as_ref()) {
445
2
            false => return resp_invalid_client(None),
446
6
            true => (),
447
        },
448
    }
449

            
450
    // Reuse the issuer to generate tokens.
451
2
    let grant = Grant {
452
6
        owner_id: client.user_id,
453
6
        client_id: client.client_id,
454
6
        scope: match client.scopes.as_slice().join(" ").parse() {
455
            Err(_) => return resp_invalid_client(Some("no valid scope")),
456
6
            Ok(scope) => scope,
457
6
        },
458
6
        redirect_uri: match client.redirect_uris.get(0) {
459
2
            None => return resp_invalid_client(Some("no valid redirect_uri")),
460
4
            Some(uri) => match Url::parse(uri.as_str()) {
461
2
                Err(_) => return resp_invalid_client(Some("invalid redirect_uri")),
462
2
                Ok(uri) => uri,
463
2
            },
464
2
        },
465
2
        until: match TimeDelta::try_minutes(10) {
466
            None => panic!("{}", E_UNKNOWN),
467
2
            Some(t) => Utc::now() + t,
468
2
        },
469
2
        extensions: Extensions::new(),
470
    };
471
2
    let token = match endpoint.issuer().issue(grant).await {
472
        Err(_) => return resp_temporary_unavailable(None),
473
2
        Ok(token) => token,
474
2
    };
475
2

            
476
2
    Json(TokenResponse {
477
2
        access_token: Some(token.token),
478
2
        refresh_token: token.refresh,
479
2
        token_type: Some("bearer".to_string()),
480
2
        expires_in: Some(token.until.signed_duration_since(Utc::now()).num_seconds()),
481
2
        scope: Some(client.scopes.as_slice().join(" ")),
482
2
        error: None,
483
2
    })
484
2
    .into_response()
485
16
}
486

            
487
/// To check the authorization grant flow parameters.
488
316
async fn check_auth_params(
489
316
    fn_name: &str,
490
316
    req: &GetAuthRequest,
491
316
    model: &Arc<dyn Model>,
492
316
) -> Result<(), Response> {
493
316
    if req.response_type != "code" {
494
6
        return Err(resp_invalid_request(Some("unsupport response_type")));
495
310
    }
496
310

            
497
310
    let client_cond = ClientQueryCond {
498
310
        user_id: None,
499
310
        client_id: Some(req.client_id.as_str()),
500
310
    };
501
310
    match model.client().get(&client_cond).await {
502
        Err(e) => {
503
            let err_str = e.to_string();
504
            error!("[{}] get client DB error: {}", fn_name, err_str.as_str());
505
            return Err(resp_temporary_unavailable(Some(err_str)));
506
        }
507
310
        Ok(client) => match client {
508
            None => {
509
2
                return Err(resp_invalid_request(Some("invalid client")));
510
            }
511
308
            Some(client) => {
512
308
                if !client.redirect_uris.contains(&req.redirect_uri) {
513
2
                    return Err(resp_invalid_request(Some("invalid redirect_uri")));
514
306
                } else if client.scopes.len() > 0 {
515
18
                    if req.scope.is_none() {
516
6
                        return Err(redirect_invalid_scope(&req.redirect_uri));
517
12
                    }
518
12
                    let req_scopes = match req.scope.as_ref().unwrap().parse::<Scope>() {
519
2
                        Err(_e) => {
520
2
                            // TODO: handle this with the error reason.
521
2
                            return Err(redirect_invalid_scope(&req.redirect_uri));
522
                        }
523
10
                        Ok(scopes) => scopes,
524
                    };
525
10
                    let client_scopes = match client.scopes.join(" ").parse::<Scope>() {
526
                        Err(e) => {
527
                            error!("[{}] parse client scopes error: {}", fn_name, e);
528
                            return Err(redirect_server_error(fn_name, &req.redirect_uri, None));
529
                        }
530
10
                        Ok(scopes) => scopes,
531
10
                    };
532
10
                    if !req_scopes.allow_access(&client_scopes) {
533
2
                        return Err(redirect_invalid_scope(&req.redirect_uri));
534
8
                    }
535
288
                }
536
            }
537
        },
538
    }
539
296
    Ok(())
540
316
}
541

            
542
10
fn redirect_invalid_scope(redirect_uri: &str) -> Response {
543
10
    resp_found(format!("{}?error=invalid_scope", redirect_uri))
544
10
}
545

            
546
fn redirect_server_error(fn_name: &str, redirect_uri: &str, description: Option<&str>) -> Response {
547
    let location = match description {
548
        None => format!("{}?error=server_error", redirect_uri),
549
        Some(desc) => {
550
            let err_desc = [("error_description", desc)];
551
            match serde_urlencoded::to_string(&err_desc) {
552
                Err(e) => {
553
                    error!("[{}] redirect server_error encode error: {}", fn_name, e);
554
                    format!("{}?error=server_error", redirect_uri)
555
                }
556
                Ok(qs) => format!("{}?error=server_error&{}", redirect_uri, qs),
557
            }
558
        }
559
    };
560
    resp_found(location)
561
}
562

            
563
588
fn resp_found(location: String) -> Response {
564
588
    (StatusCode::FOUND, [(header::LOCATION, location)]).into_response()
565
588
}
566

            
567
6
fn resp_invalid_auth<'a>(description: Option<&'a str>) -> Response {
568
6
    let description = match description {
569
6
        None => None,
570
        Some(description) => Some(description.to_string()),
571
    };
572
6
    (
573
6
        StatusCode::BAD_REQUEST,
574
6
        Json(OAuth2Error::new("invalid_auth", description)),
575
6
    )
576
6
        .into_response()
577
6
}
578

            
579
12
fn resp_invalid_client<'a>(description: Option<&'a str>) -> Response {
580
12
    let description = match description {
581
8
        None => None,
582
4
        Some(description) => Some(description.to_string()),
583
    };
584
12
    (
585
12
        StatusCode::UNAUTHORIZED,
586
12
        Json(OAuth2Error::new("invalid_client", description)),
587
12
    )
588
12
        .into_response()
589
12
}
590

            
591
28
fn resp_invalid_request<'a>(description: Option<&'a str>) -> Response {
592
28
    let description = match description {
593
10
        None => None,
594
18
        Some(description) => Some(description.to_string()),
595
    };
596
28
    (
597
28
        StatusCode::BAD_REQUEST,
598
28
        Json(OAuth2Error::new("invalid_request", description)),
599
28
    )
600
28
        .into_response()
601
28
}
602

            
603
fn resp_temporary_unavailable(description: Option<String>) -> Response {
604
    (
605
        StatusCode::SERVICE_UNAVAILABLE,
606
        Json(OAuth2Error::new("temporarily_unavailable", description)),
607
    )
608
        .into_response()
609
}