1
use std::{borrow::Cow, str};
2

            
3
use axum::{
4
    body::Bytes,
5
    extract::{Form, FromRequest, Query, Request},
6
    http::{Method, header},
7
    response::{IntoResponse, Response},
8
};
9
use base64::{Engine, engine::general_purpose};
10
use oxide_auth::code_grant::{
11
    accesstoken::{Authorization, Request as OxideAccessTokenRequest},
12
    authorization::Request as OxideAuthorizationRequest,
13
    refresh::Request as OxideRefreshTokenRequest,
14
};
15
use serde::{Deserialize, Serialize};
16

            
17
use super::response::OAuth2Error;
18

            
19
/// Query-string parameters of an OAuth2 authorization request.
///
/// Extracted with axum's [`Query`] by this type's `FromRequest` impl; a query
/// that fails to deserialize (e.g. a missing required field) is rejected with
/// an OAuth2 request error. Field values are not validated here.
// NOTE(review): also derives `Serialize`, presumably so the parameters can be
// re-encoded later (e.g. into a redirect) — confirm at call sites.
#[derive(Deserialize, Serialize)]
pub struct GetAuthRequest {
    /// OAuth2 `response_type` parameter.
    pub response_type: String,
    /// Identifier of the requesting client.
    pub client_id: String,
    /// Redirect URI supplied by the client.
    pub redirect_uri: String,
    /// Optional requested scope string.
    pub scope: Option<String>,
    /// Optional opaque `state` value supplied by the client.
    pub state: Option<String>,
}
27

            
28
/// Query-string parameters of the login page request.
///
/// Extracted with axum's [`Query`] by this type's `FromRequest` impl.
#[derive(Deserialize, Serialize)]
pub struct GetLoginRequest {
    /// Opaque state value carried through the login flow.
    // NOTE(review): presumably links the login back to a pending authorization
    // request — confirm against the handler that produces it.
    pub state: String,
}
32

            
33
/// Submitted login form, extracted with axum's [`Form`] by this type's
/// `FromRequest` impl.
// NOTE(review): holds a plaintext password — keep it out of logs and avoid
// deriving `Debug`/`Serialize` on this type.
#[derive(Deserialize)]
pub struct PostLoginRequest {
    /// Account name entered by the user.
    pub account: String,
    /// Password entered by the user (plaintext).
    pub password: String,
    /// Opaque state value echoed back from the login page.
    pub state: String,
}
39

            
40
/// Authorization-endpoint request accepted either as a query string (`GET`)
/// or as a form body (`POST`); other methods are rejected by the extractor.
///
/// Besides the OAuth2 parameters it carries a `session_id` and an optional
/// `allow` answer, interpreted by [`AuthorizationRequest::allowed`].
/// Implements oxide_auth's authorization `Request` trait.
#[derive(Deserialize, Serialize)]
pub struct AuthorizationRequest {
    // OAuth2 `response_type` parameter.
    response_type: String,
    // Identifier of the requesting client.
    client_id: String,
    // Redirect URI supplied by the client.
    redirect_uri: String,
    // Optional requested scope string.
    scope: Option<String>,
    // Optional opaque client `state`.
    state: Option<String>,
    // Session identifier exposed via `session_id()`.
    // NOTE(review): presumably the login session of the user — confirm.
    session_id: String,
    // Consent answer; only `ALLOW_VALUE` counts as granted (see `allowed()`).
    allow: Option<String>,
}
50

            
51
/// Token-endpoint request, extracted with axum's [`Form`] by this type's
/// `FromRequest` impl. Implements oxide_auth's access-token `Request` trait.
#[derive(Deserialize, Serialize)]
pub struct AccessTokenRequest {
    // Not part of the form: skipped by serde and filled by the extractor from
    // the HTTP Basic `Authorization` header as (user, raw password bytes).
    #[serde(skip)]
    authorization: Option<(String, Vec<u8>)>,
    // OAuth2 `grant_type` parameter.
    grant_type: String,
    code: Option<String>,         // for authorization code grant flow
    redirect_uri: Option<String>, // for authorization code grant flow
    // Optional client identifier sent in the form.
    client_id: Option<String>,
    scope: Option<String>, // for client credentials grant flow
}
61

            
62
/// Refresh-token request, extracted with axum's [`Form`] by this type's
/// `FromRequest` impl. Implements oxide_auth's refresh `Request` trait.
#[derive(Deserialize, Serialize)]
pub struct RefreshTokenRequest {
    // Not part of the form: skipped by serde and filled by the extractor from
    // the HTTP Basic `Authorization` header as (user, raw password bytes).
    #[serde(skip)]
    authorization: Option<(String, Vec<u8>)>,
    // OAuth2 `grant_type` parameter.
    grant_type: String,
    // The refresh token being exchanged.
    refresh_token: String,
    // Optional requested scope string.
    scope: Option<String>,
    // Optional client identifier sent in the form.
    client_id: Option<String>,
}
71

            
72
/// Value of the `allow` field that grants an authorization request; any other
/// value is reported as a refusal by [`AuthorizationRequest::allowed`].
pub const ALLOW_VALUE: &str = "yes";
73

            
74
impl<S> FromRequest<S> for GetAuthRequest
75
where
76
    S: Send + Sync,
77
{
78
    type Rejection = Response;
79

            
80
36
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
81
36
        match Query::<GetAuthRequest>::from_request(req, state).await {
82
4
            Err(e) => Err(OAuth2Error::new_request(Some(e.to_string())).into_response()),
83
32
            Ok(request) => Ok(request.0),
84
        }
85
36
    }
86
}
87

            
88
impl<S> FromRequest<S> for GetLoginRequest
89
where
90
    S: Send + Sync,
91
{
92
    type Rejection = Response;
93

            
94
28
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
95
28
        match Query::<GetLoginRequest>::from_request(req, state).await {
96
4
            Err(e) => Err(OAuth2Error::new_request(Some(e.to_string())).into_response()),
97
24
            Ok(request) => Ok(request.0),
98
        }
99
28
    }
100
}
101

            
102
impl<S> FromRequest<S> for PostLoginRequest
103
where
104
    Bytes: FromRequest<S>,
105
    S: Send + Sync,
106
{
107
    type Rejection = Response;
108

            
109
596
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
110
596
        match Form::<PostLoginRequest>::from_request(req, state).await {
111
4
            Err(e) => Err(OAuth2Error::new_request(Some(e.to_string())).into_response()),
112
592
            Ok(body) => Ok(body.0),
113
        }
114
596
    }
115
}
116

            
117
impl AuthorizationRequest {
118
572
    pub fn session_id(&self) -> &str {
119
572
        self.session_id.as_str()
120
572
    }
121

            
122
580
    pub fn allowed(&self) -> Option<bool> {
123
580
        if let Some(allow_str) = self.allow.as_deref() {
124
560
            return Some(allow_str == ALLOW_VALUE);
125
20
        }
126
20
        None
127
580
    }
128
}
129

            
130
impl<S> FromRequest<S> for AuthorizationRequest
131
where
132
    Bytes: FromRequest<S>,
133
    S: Send + Sync,
134
{
135
    type Rejection = Response;
136

            
137
628
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
138
628
        match *req.method() {
139
36
            Method::GET => match Query::<AuthorizationRequest>::from_request(req, state).await {
140
4
                Err(e) => Err(OAuth2Error::new_request(Some(e.to_string())).into_response()),
141
32
                Ok(request) => Ok(request.0),
142
            },
143
592
            Method::POST => match Form::<AuthorizationRequest>::from_request(req, state).await {
144
4
                Err(e) => Err(OAuth2Error::new_request(Some(e.to_string())).into_response()),
145
588
                Ok(request) => Ok(request.0),
146
            },
147
            _ => Err(OAuth2Error::new_request(Some("invalid method".to_string())).into_response()),
148
        }
149
628
    }
150
}
151

            
152
impl OxideAuthorizationRequest for AuthorizationRequest {
153
620
    fn valid(&self) -> bool {
154
620
        true
155
620
    }
156

            
157
640
    fn client_id(&self) -> Option<Cow<str>> {
158
640
        Some(Cow::from(self.client_id.as_str()))
159
640
    }
160

            
161
616
    fn scope(&self) -> Option<Cow<str>> {
162
616
        match self.scope.as_ref() {
163
492
            None => None,
164
124
            Some(scope) => Some(Cow::from(scope)),
165
        }
166
616
    }
167

            
168
640
    fn redirect_uri(&self) -> Option<Cow<str>> {
169
640
        Some(Cow::from(&self.redirect_uri))
170
640
    }
171

            
172
624
    fn state(&self) -> Option<Cow<str>> {
173
624
        match self.state.as_ref() {
174
612
            None => None,
175
12
            Some(state) => Some(Cow::from(state)),
176
        }
177
624
    }
178

            
179
624
    fn response_type(&self) -> Option<Cow<str>> {
180
624
        Some(Cow::from(&self.response_type))
181
624
    }
182

            
183
    fn extension(&self, _key: &str) -> Option<Cow<str>> {
184
        None
185
    }
186
}
187

            
188
impl<S> FromRequest<S> for AccessTokenRequest
189
where
190
    Bytes: FromRequest<S>,
191
    S: Send + Sync,
192
{
193
    type Rejection = Response;
194

            
195
584
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
196
584
        let authorization = match parse_basic_auth(&req) {
197
            Err(e) => return Err(e.into_response()),
198
584
            Ok(auth) => auth,
199
        };
200
584
        let mut request = match Form::<AccessTokenRequest>::from_request(req, state).await {
201
4
            Err(e) => return Err(OAuth2Error::new_request(Some(e.to_string())).into_response()),
202
580
            Ok(request) => request.0,
203
580
        };
204
580
        request.authorization = authorization;
205
580
        Ok(request)
206
584
    }
207
}
208

            
209
impl OxideAccessTokenRequest for AccessTokenRequest {
210
548
    fn valid(&self) -> bool {
211
548
        true
212
548
    }
213

            
214
524
    fn code(&self) -> Option<Cow<str>> {
215
524
        match self.code.as_ref() {
216
            None => None,
217
524
            Some(code) => Some(Cow::from(code)),
218
        }
219
524
    }
220

            
221
580
    fn authorization(&self) -> Authorization {
222
580
        match self.authorization.as_ref() {
223
492
            None => Authorization::None,
224
88
            Some(auth) => match auth.1.len() {
225
12
                0 => Authorization::Username(Cow::from(auth.0.as_str())),
226
76
                _ => Authorization::UsernamePassword(
227
76
                    Cow::from(auth.0.as_str()),
228
76
                    Cow::from(auth.1.as_slice()),
229
76
                ),
230
            },
231
        }
232
580
    }
233

            
234
548
    fn client_id(&self) -> Option<Cow<str>> {
235
548
        match self.client_id.as_ref() {
236
68
            None => None,
237
480
            Some(id) => Some(Cow::from(id)),
238
        }
239
548
    }
240

            
241
532
    fn redirect_uri(&self) -> Option<Cow<str>> {
242
532
        match self.redirect_uri.as_ref() {
243
            None => None,
244
532
            Some(uri) => Some(Cow::from(uri)),
245
        }
246
532
    }
247

            
248
1128
    fn grant_type(&self) -> Option<Cow<str>> {
249
1128
        Some(Cow::from(&self.grant_type))
250
1128
    }
251

            
252
548
    fn extension(&self, _key: &str) -> Option<Cow<str>> {
253
548
        None
254
548
    }
255

            
256
    fn allow_credentials_in_body(&self) -> bool {
257
        false
258
    }
259
}
260

            
261
impl<S> FromRequest<S> for RefreshTokenRequest
262
where
263
    Bytes: FromRequest<S>,
264
    S: Send + Sync,
265
{
266
    type Rejection = Response;
267

            
268
72
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
269
72
        let authorization = match parse_basic_auth(&req) {
270
            Err(e) => return Err(e.into_response()),
271
72
            Ok(auth) => auth,
272
        };
273
72
        let mut request = match Form::<RefreshTokenRequest>::from_request(req, state).await {
274
4
            Err(e) => {
275
4
                return Err(OAuth2Error::new_request(Some(e.to_string())).into_response());
276
            }
277
68
            Ok(request) => request.0,
278
68
        };
279
68
        request.authorization = authorization;
280
68
        Ok(request)
281
72
    }
282
}
283

            
284
impl OxideRefreshTokenRequest for RefreshTokenRequest {
285
68
    fn valid(&self) -> bool {
286
68
        true
287
68
    }
288

            
289
68
    fn refresh_token(&self) -> Option<Cow<str>> {
290
68
        Some(Cow::from(&self.refresh_token))
291
68
    }
292

            
293
92
    fn scope(&self) -> Option<Cow<str>> {
294
92
        match self.scope.as_ref() {
295
56
            None => None,
296
36
            Some(scope) => Some(Cow::from(scope)),
297
        }
298
92
    }
299

            
300
68
    fn grant_type(&self) -> Option<Cow<str>> {
301
68
        Some(Cow::from(&self.grant_type))
302
68
    }
303

            
304
60
    fn authorization(&self) -> Option<(Cow<str>, Cow<[u8]>)> {
305
60
        match self.authorization.as_ref() {
306
48
            None => None,
307
12
            Some(auth) => Some((Cow::from(auth.0.as_str()), Cow::from(auth.1.as_slice()))),
308
        }
309
60
    }
310

            
311
    fn extension(&self, _key: &str) -> Option<Cow<str>> {
312
        None
313
    }
314
}
315

            
316
656
fn parse_basic_auth(req: &Request) -> Result<Option<(String, Vec<u8>)>, OAuth2Error> {
317
656
    let mut auth_all = req.headers().get_all(header::AUTHORIZATION).iter();
318
656
    let auth = match auth_all.next() {
319
556
        None => return Ok(None),
320
100
        Some(auth) => match auth.to_str() {
321
            Err(e) => return Err(OAuth2Error::new_request(Some(e.to_string()))),
322
100
            Ok(auth) => auth,
323
100
        },
324
100
    };
325
100
    if auth_all.next() != None {
326
        return Err(OAuth2Error::new_request(Some(
327
            "invalid multiple Authorization header".to_string(),
328
        )));
329
100
    } else if !auth.starts_with("Basic ") || auth.len() < 7 {
330
        return Err(OAuth2Error::new_request(Some(
331
            "not a Basic header".to_string(),
332
        )));
333
100
    }
334
100
    let auth = match general_purpose::STANDARD.decode(&auth[6..]) {
335
        Err(e) => match general_purpose::STANDARD_NO_PAD.decode(&auth[6..]) {
336
            Err(_) => {
337
                return Err(OAuth2Error::new_request(Some(format!(
338
                    "invalid Basic content: {}",
339
                    e
340
                ))));
341
            }
342
            Ok(auth) => auth,
343
        },
344
100
        Ok(auth) => auth,
345
    };
346
836
    let mut split = auth.splitn(2, |&c| c == b':');
347
100
    let user = match split.next() {
348
        None => {
349
            return Err(OAuth2Error::new_request(Some(
350
                "invalid Basic content".to_string(),
351
            )));
352
        }
353
100
        Some(user) => user,
354
    };
355
100
    let pass = match split.next() {
356
        None => {
357
            return Err(OAuth2Error::new_request(Some(
358
                "invalid Basic content".to_string(),
359
            )));
360
        }
361
100
        Some(pass) => pass,
362
    };
363
100
    let user = match str::from_utf8(user) {
364
        Err(e) => {
365
            return Err(OAuth2Error::new_request(Some(format!(
366
                "invalid Basic content: {}",
367
                e
368
            ))));
369
        }
370
100
        Ok(user) => user,
371
100
    };
372
100
    Ok(Some((user.to_string(), pass.to_vec())))
373
656
}