1
use std::{borrow::Cow, str};
2

            
3
use async_trait::async_trait;
4
use axum::{
5
    body::Bytes,
6
    extract::{Form, FromRequest, Query, Request},
7
    http::{header, Method},
8
    response::{IntoResponse, Response},
9
};
10
use base64::{engine::general_purpose, Engine};
11
use oxide_auth::code_grant::{
12
    accesstoken::{Authorization, Request as OxideAccessTokenRequest},
13
    authorization::Request as OxideAuthorizationRequest,
14
    refresh::Request as OxideRefreshTokenRequest,
15
};
16
use serde::{Deserialize, Serialize};
17

            
18
use super::response::OAuth2Error;
19

            
20
1500
#[derive(Deserialize, Serialize)]
21
pub struct GetAuthRequest {
22
    pub response_type: String,
23
    pub client_id: String,
24
    pub redirect_uri: String,
25
    pub scope: Option<String>,
26
    pub state: Option<String>,
27
}
28

            
29
26
#[derive(Deserialize, Serialize)]
30
pub struct GetLoginRequest {
31
    pub state: String,
32
}
33

            
34
1184
#[derive(Deserialize)]
35
pub struct PostLoginRequest {
36
    pub account: String,
37
    pub password: String,
38
    pub state: String,
39
}
40

            
41
1894
#[derive(Deserialize, Serialize)]
42
pub struct AuthorizationRequest {
43
    response_type: String,
44
    client_id: String,
45
    redirect_uri: String,
46
    scope: Option<String>,
47
    state: Option<String>,
48
    session_id: String,
49
    allow: Option<String>,
50
}
51

            
52
1368
#[derive(Deserialize, Serialize)]
53
pub struct AccessTokenRequest {
54
    #[serde(skip)]
55
    authorization: Option<(String, Vec<u8>)>,
56
    grant_type: String,
57
    code: Option<String>,         // for authorization code grant flow
58
    redirect_uri: Option<String>, // for authorization code grant flow
59
    client_id: Option<String>,
60
    scope: Option<String>, // for client credentials grant flow
61
}
62

            
63
138
#[derive(Deserialize, Serialize)]
64
pub struct RefreshTokenRequest {
65
    #[serde(skip)]
66
    authorization: Option<(String, Vec<u8>)>,
67
    grant_type: String,
68
    refresh_token: String,
69
    scope: Option<String>,
70
    client_id: Option<String>,
71
}
72

            
73
/// Value of the `allow` parameter that `AuthorizationRequest::allowed` treats as
/// allowed (`Some(true)`); any other submitted value yields `Some(false)`.
// `&str` in a const is already `'static`; the explicit lifetime is redundant
// (clippy::redundant_static_lifetimes).
pub const ALLOW_VALUE: &str = "yes";
74

            
75
#[async_trait]
76
impl<S> FromRequest<S> for GetAuthRequest
77
where
78
    S: Send + Sync,
79
{
80
    type Rejection = Response;
81

            
82
18
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
83
18
        match Query::<GetAuthRequest>::from_request(req, state).await {
84
2
            Err(e) => Err(OAuth2Error::new_request(Some(e.to_string())).into_response()),
85
16
            Ok(request) => Ok(request.0),
86
        }
87
36
    }
88
}
89

            
90
#[async_trait]
91
impl<S> FromRequest<S> for GetLoginRequest
92
where
93
    S: Send + Sync,
94
{
95
    type Rejection = Response;
96

            
97
14
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
98
14
        match Query::<GetLoginRequest>::from_request(req, state).await {
99
2
            Err(e) => Err(OAuth2Error::new_request(Some(e.to_string())).into_response()),
100
12
            Ok(request) => Ok(request.0),
101
        }
102
28
    }
103
}
104

            
105
#[async_trait]
106
impl<S> FromRequest<S> for PostLoginRequest
107
where
108
    Bytes: FromRequest<S>,
109
    S: Send + Sync,
110
{
111
    type Rejection = Response;
112

            
113
298
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
114
298
        match Form::<PostLoginRequest>::from_request(req, state).await {
115
2
            Err(e) => Err(OAuth2Error::new_request(Some(e.to_string())).into_response()),
116
296
            Ok(body) => Ok(body.0),
117
        }
118
596
    }
119
}
120

            
121
impl AuthorizationRequest {
122
286
    pub fn session_id(&self) -> &str {
123
286
        self.session_id.as_str()
124
286
    }
125

            
126
290
    pub fn allowed(&self) -> Option<bool> {
127
290
        if let Some(allow_str) = self.allow.as_deref() {
128
280
            return Some(allow_str == ALLOW_VALUE);
129
10
        }
130
10
        None
131
290
    }
132
}
133

            
134
#[async_trait]
135
impl<S> FromRequest<S> for AuthorizationRequest
136
where
137
    Bytes: FromRequest<S>,
138
    S: Send + Sync,
139
{
140
    type Rejection = Response;
141

            
142
314
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
143
314
        match *req.method() {
144
18
            Method::GET => match Query::<AuthorizationRequest>::from_request(req, state).await {
145
2
                Err(e) => Err(OAuth2Error::new_request(Some(e.to_string())).into_response()),
146
16
                Ok(request) => Ok(request.0),
147
            },
148
296
            Method::POST => match Form::<AuthorizationRequest>::from_request(req, state).await {
149
2
                Err(e) => Err(OAuth2Error::new_request(Some(e.to_string())).into_response()),
150
294
                Ok(request) => Ok(request.0),
151
            },
152
            _ => Err(OAuth2Error::new_request(Some("invalid method".to_string())).into_response()),
153
        }
154
628
    }
155
}
156

            
157
impl OxideAuthorizationRequest for AuthorizationRequest {
158
310
    fn valid(&self) -> bool {
159
310
        true
160
310
    }
161

            
162
320
    fn client_id(&self) -> Option<Cow<str>> {
163
320
        Some(Cow::from(self.client_id.as_str()))
164
320
    }
165

            
166
308
    fn scope(&self) -> Option<Cow<str>> {
167
308
        match self.scope.as_ref() {
168
246
            None => None,
169
62
            Some(scope) => Some(Cow::from(scope)),
170
        }
171
308
    }
172

            
173
320
    fn redirect_uri(&self) -> Option<Cow<str>> {
174
320
        Some(Cow::from(&self.redirect_uri))
175
320
    }
176

            
177
312
    fn state(&self) -> Option<Cow<str>> {
178
312
        match self.state.as_ref() {
179
306
            None => None,
180
6
            Some(state) => Some(Cow::from(state)),
181
        }
182
312
    }
183

            
184
312
    fn response_type(&self) -> Option<Cow<str>> {
185
312
        Some(Cow::from(&self.response_type))
186
312
    }
187

            
188
    fn extension(&self, _key: &str) -> Option<Cow<str>> {
189
        None
190
    }
191
}
192

            
193
#[async_trait]
194
impl<S> FromRequest<S> for AccessTokenRequest
195
where
196
    Bytes: FromRequest<S>,
197
    S: Send + Sync,
198
{
199
    type Rejection = Response;
200

            
201
292
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
202
292
        let authorization = match parse_basic_auth(&req) {
203
            Err(e) => return Err(e.into_response()),
204
292
            Ok(auth) => auth,
205
        };
206
292
        let mut request = match Form::<AccessTokenRequest>::from_request(req, state).await {
207
2
            Err(e) => return Err(OAuth2Error::new_request(Some(e.to_string())).into_response()),
208
290
            Ok(request) => request.0,
209
290
        };
210
290
        request.authorization = authorization;
211
290
        Ok(request)
212
584
    }
213
}
214

            
215
impl OxideAccessTokenRequest for AccessTokenRequest {
216
274
    fn valid(&self) -> bool {
217
274
        true
218
274
    }
219

            
220
262
    fn code(&self) -> Option<Cow<str>> {
221
262
        match self.code.as_ref() {
222
            None => None,
223
262
            Some(code) => Some(Cow::from(code)),
224
        }
225
262
    }
226

            
227
290
    fn authorization(&self) -> Authorization {
228
290
        match self.authorization.as_ref() {
229
246
            None => Authorization::None,
230
44
            Some(auth) => match auth.1.len() {
231
6
                0 => Authorization::Username(Cow::from(auth.0.as_str())),
232
38
                _ => Authorization::UsernamePassword(
233
38
                    Cow::from(auth.0.as_str()),
234
38
                    Cow::from(auth.1.as_slice()),
235
38
                ),
236
            },
237
        }
238
290
    }
239

            
240
274
    fn client_id(&self) -> Option<Cow<str>> {
241
274
        match self.client_id.as_ref() {
242
34
            None => None,
243
240
            Some(id) => Some(Cow::from(id)),
244
        }
245
274
    }
246

            
247
266
    fn redirect_uri(&self) -> Option<Cow<str>> {
248
266
        match self.redirect_uri.as_ref() {
249
            None => None,
250
266
            Some(uri) => Some(Cow::from(uri)),
251
        }
252
266
    }
253

            
254
564
    fn grant_type(&self) -> Option<Cow<str>> {
255
564
        Some(Cow::from(&self.grant_type))
256
564
    }
257

            
258
274
    fn extension(&self, _key: &str) -> Option<Cow<str>> {
259
274
        None
260
274
    }
261

            
262
    fn allow_credentials_in_body(&self) -> bool {
263
        false
264
    }
265
}
266

            
267
#[async_trait]
268
impl<S> FromRequest<S> for RefreshTokenRequest
269
where
270
    Bytes: FromRequest<S>,
271
    S: Send + Sync,
272
{
273
    type Rejection = Response;
274

            
275
36
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
276
36
        let authorization = match parse_basic_auth(&req) {
277
            Err(e) => return Err(e.into_response()),
278
36
            Ok(auth) => auth,
279
        };
280
36
        let mut request = match Form::<RefreshTokenRequest>::from_request(req, state).await {
281
2
            Err(e) => {
282
2
                return Err(OAuth2Error::new_request(Some(e.to_string())).into_response());
283
            }
284
34
            Ok(request) => request.0,
285
34
        };
286
34
        request.authorization = authorization;
287
34
        Ok(request)
288
72
    }
289
}
290

            
291
impl OxideRefreshTokenRequest for RefreshTokenRequest {
292
34
    fn valid(&self) -> bool {
293
34
        true
294
34
    }
295

            
296
34
    fn refresh_token(&self) -> Option<Cow<str>> {
297
34
        Some(Cow::from(&self.refresh_token))
298
34
    }
299

            
300
46
    fn scope(&self) -> Option<Cow<str>> {
301
46
        match self.scope.as_ref() {
302
28
            None => None,
303
18
            Some(scope) => Some(Cow::from(scope)),
304
        }
305
46
    }
306

            
307
34
    fn grant_type(&self) -> Option<Cow<str>> {
308
34
        Some(Cow::from(&self.grant_type))
309
34
    }
310

            
311
30
    fn authorization(&self) -> Option<(Cow<str>, Cow<[u8]>)> {
312
30
        match self.authorization.as_ref() {
313
24
            None => None,
314
6
            Some(auth) => Some((Cow::from(auth.0.as_str()), Cow::from(auth.1.as_slice()))),
315
        }
316
30
    }
317

            
318
    fn extension(&self, _key: &str) -> Option<Cow<str>> {
319
        None
320
    }
321
}
322

            
323
328
fn parse_basic_auth(req: &Request) -> Result<Option<(String, Vec<u8>)>, OAuth2Error> {
324
328
    let mut auth_all = req.headers().get_all(header::AUTHORIZATION).iter();
325
328
    let auth = match auth_all.next() {
326
278
        None => return Ok(None),
327
50
        Some(auth) => match auth.to_str() {
328
            Err(e) => return Err(OAuth2Error::new_request(Some(e.to_string()))),
329
50
            Ok(auth) => auth,
330
50
        },
331
50
    };
332
50
    if auth_all.next() != None {
333
        return Err(OAuth2Error::new_request(Some(
334
            "invalid multiple Authorization header".to_string(),
335
        )));
336
50
    } else if !auth.starts_with("Basic ") || auth.len() < 7 {
337
        return Err(OAuth2Error::new_request(Some(
338
            "not a Basic header".to_string(),
339
        )));
340
50
    }
341
50
    let auth = match general_purpose::STANDARD.decode(&auth[6..]) {
342
        Err(e) => match general_purpose::STANDARD_NO_PAD.decode(&auth[6..]) {
343
            Err(_) => {
344
                return Err(OAuth2Error::new_request(Some(format!(
345
                    "invalid Basic content: {}",
346
                    e
347
                ))))
348
            }
349
            Ok(auth) => auth,
350
        },
351
50
        Ok(auth) => auth,
352
    };
353
418
    let mut split = auth.splitn(2, |&c| c == b':');
354
50
    let user = match split.next() {
355
        None => {
356
            return Err(OAuth2Error::new_request(Some(
357
                "invalid Basic content".to_string(),
358
            )))
359
        }
360
50
        Some(user) => user,
361
    };
362
50
    let pass = match split.next() {
363
        None => {
364
            return Err(OAuth2Error::new_request(Some(
365
                "invalid Basic content".to_string(),
366
            )))
367
        }
368
50
        Some(pass) => pass,
369
    };
370
50
    let user = match str::from_utf8(user) {
371
        Err(e) => {
372
            return Err(OAuth2Error::new_request(Some(format!(
373
                "invalid Basic content: {}",
374
                e
375
            ))))
376
        }
377
50
        Ok(user) => user,
378
50
    };
379
50
    Ok(Some((user.to_string(), pass.to_vec())))
380
328
}