1
use std::{borrow::Cow, str};
2

            
3
use axum::{
4
    body::Bytes,
5
    extract::{Form, FromRequest, Query, Request},
6
    http::{header, Method},
7
    response::{IntoResponse, Response},
8
};
9
use base64::{engine::general_purpose, Engine};
10
use oxide_auth::code_grant::{
11
    accesstoken::{Authorization, Request as OxideAccessTokenRequest},
12
    authorization::Request as OxideAuthorizationRequest,
13
    refresh::Request as OxideRefreshTokenRequest,
14
};
15
use serde::{Deserialize, Serialize};
16

            
17
use super::response::OAuth2Error;
18

            
19
#[derive(Deserialize, Serialize)]
20
pub struct GetAuthRequest {
21
    pub response_type: String,
22
    pub client_id: String,
23
    pub redirect_uri: String,
24
    pub scope: Option<String>,
25
    pub state: Option<String>,
26
}
27

            
28
#[derive(Deserialize, Serialize)]
29
pub struct GetLoginRequest {
30
    pub state: String,
31
}
32

            
33
#[derive(Deserialize)]
34
pub struct PostLoginRequest {
35
    pub account: String,
36
    pub password: String,
37
    pub state: String,
38
}
39

            
40
#[derive(Deserialize, Serialize)]
41
pub struct AuthorizationRequest {
42
    response_type: String,
43
    client_id: String,
44
    redirect_uri: String,
45
    scope: Option<String>,
46
    state: Option<String>,
47
    session_id: String,
48
    allow: Option<String>,
49
}
50

            
51
#[derive(Deserialize, Serialize)]
52
pub struct AccessTokenRequest {
53
    #[serde(skip)]
54
    authorization: Option<(String, Vec<u8>)>,
55
    grant_type: String,
56
    code: Option<String>,         // for authorization code grant flow
57
    redirect_uri: Option<String>, // for authorization code grant flow
58
    client_id: Option<String>,
59
    scope: Option<String>, // for client credentials grant flow
60
}
61

            
62
#[derive(Deserialize, Serialize)]
63
pub struct RefreshTokenRequest {
64
    #[serde(skip)]
65
    authorization: Option<(String, Vec<u8>)>,
66
    grant_type: String,
67
    refresh_token: String,
68
    scope: Option<String>,
69
    client_id: Option<String>,
70
}
71

            
72
/// Value of the `allow` parameter that `AuthorizationRequest::allowed` treats as
/// the user granting access; any other value is treated as a refusal.
pub const ALLOW_VALUE: &str = "yes";
73

            
74
impl<S> FromRequest<S> for GetAuthRequest
75
where
76
    S: Send + Sync,
77
{
78
    type Rejection = Response;
79

            
80
18
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
81
18
        match Query::<GetAuthRequest>::from_request(req, state).await {
82
2
            Err(e) => Err(OAuth2Error::new_request(Some(e.to_string())).into_response()),
83
16
            Ok(request) => Ok(request.0),
84
        }
85
18
    }
86
}
87

            
88
impl<S> FromRequest<S> for GetLoginRequest
89
where
90
    S: Send + Sync,
91
{
92
    type Rejection = Response;
93

            
94
14
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
95
14
        match Query::<GetLoginRequest>::from_request(req, state).await {
96
2
            Err(e) => Err(OAuth2Error::new_request(Some(e.to_string())).into_response()),
97
12
            Ok(request) => Ok(request.0),
98
        }
99
14
    }
100
}
101

            
102
impl<S> FromRequest<S> for PostLoginRequest
103
where
104
    Bytes: FromRequest<S>,
105
    S: Send + Sync,
106
{
107
    type Rejection = Response;
108

            
109
298
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
110
298
        match Form::<PostLoginRequest>::from_request(req, state).await {
111
2
            Err(e) => Err(OAuth2Error::new_request(Some(e.to_string())).into_response()),
112
296
            Ok(body) => Ok(body.0),
113
        }
114
298
    }
115
}
116

            
117
impl AuthorizationRequest {
118
286
    pub fn session_id(&self) -> &str {
119
286
        self.session_id.as_str()
120
286
    }
121

            
122
290
    pub fn allowed(&self) -> Option<bool> {
123
290
        if let Some(allow_str) = self.allow.as_deref() {
124
280
            return Some(allow_str == ALLOW_VALUE);
125
10
        }
126
10
        None
127
290
    }
128
}
129

            
130
impl<S> FromRequest<S> for AuthorizationRequest
131
where
132
    Bytes: FromRequest<S>,
133
    S: Send + Sync,
134
{
135
    type Rejection = Response;
136

            
137
314
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
138
314
        match *req.method() {
139
18
            Method::GET => match Query::<AuthorizationRequest>::from_request(req, state).await {
140
2
                Err(e) => Err(OAuth2Error::new_request(Some(e.to_string())).into_response()),
141
16
                Ok(request) => Ok(request.0),
142
            },
143
296
            Method::POST => match Form::<AuthorizationRequest>::from_request(req, state).await {
144
2
                Err(e) => Err(OAuth2Error::new_request(Some(e.to_string())).into_response()),
145
294
                Ok(request) => Ok(request.0),
146
            },
147
            _ => Err(OAuth2Error::new_request(Some("invalid method".to_string())).into_response()),
148
        }
149
314
    }
150
}
151

            
152
impl OxideAuthorizationRequest for AuthorizationRequest {
153
310
    fn valid(&self) -> bool {
154
310
        true
155
310
    }
156

            
157
320
    fn client_id(&self) -> Option<Cow<str>> {
158
320
        Some(Cow::from(self.client_id.as_str()))
159
320
    }
160

            
161
308
    fn scope(&self) -> Option<Cow<str>> {
162
308
        match self.scope.as_ref() {
163
246
            None => None,
164
62
            Some(scope) => Some(Cow::from(scope)),
165
        }
166
308
    }
167

            
168
320
    fn redirect_uri(&self) -> Option<Cow<str>> {
169
320
        Some(Cow::from(&self.redirect_uri))
170
320
    }
171

            
172
312
    fn state(&self) -> Option<Cow<str>> {
173
312
        match self.state.as_ref() {
174
306
            None => None,
175
6
            Some(state) => Some(Cow::from(state)),
176
        }
177
312
    }
178

            
179
312
    fn response_type(&self) -> Option<Cow<str>> {
180
312
        Some(Cow::from(&self.response_type))
181
312
    }
182

            
183
    fn extension(&self, _key: &str) -> Option<Cow<str>> {
184
        None
185
    }
186
}
187

            
188
impl<S> FromRequest<S> for AccessTokenRequest
189
where
190
    Bytes: FromRequest<S>,
191
    S: Send + Sync,
192
{
193
    type Rejection = Response;
194

            
195
292
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
196
292
        let authorization = match parse_basic_auth(&req) {
197
            Err(e) => return Err(e.into_response()),
198
292
            Ok(auth) => auth,
199
        };
200
292
        let mut request = match Form::<AccessTokenRequest>::from_request(req, state).await {
201
2
            Err(e) => return Err(OAuth2Error::new_request(Some(e.to_string())).into_response()),
202
290
            Ok(request) => request.0,
203
290
        };
204
290
        request.authorization = authorization;
205
290
        Ok(request)
206
292
    }
207
}
208

            
209
impl OxideAccessTokenRequest for AccessTokenRequest {
210
274
    fn valid(&self) -> bool {
211
274
        true
212
274
    }
213

            
214
262
    fn code(&self) -> Option<Cow<str>> {
215
262
        match self.code.as_ref() {
216
            None => None,
217
262
            Some(code) => Some(Cow::from(code)),
218
        }
219
262
    }
220

            
221
290
    fn authorization(&self) -> Authorization {
222
290
        match self.authorization.as_ref() {
223
246
            None => Authorization::None,
224
44
            Some(auth) => match auth.1.len() {
225
6
                0 => Authorization::Username(Cow::from(auth.0.as_str())),
226
38
                _ => Authorization::UsernamePassword(
227
38
                    Cow::from(auth.0.as_str()),
228
38
                    Cow::from(auth.1.as_slice()),
229
38
                ),
230
            },
231
        }
232
290
    }
233

            
234
274
    fn client_id(&self) -> Option<Cow<str>> {
235
274
        match self.client_id.as_ref() {
236
34
            None => None,
237
240
            Some(id) => Some(Cow::from(id)),
238
        }
239
274
    }
240

            
241
266
    fn redirect_uri(&self) -> Option<Cow<str>> {
242
266
        match self.redirect_uri.as_ref() {
243
            None => None,
244
266
            Some(uri) => Some(Cow::from(uri)),
245
        }
246
266
    }
247

            
248
564
    fn grant_type(&self) -> Option<Cow<str>> {
249
564
        Some(Cow::from(&self.grant_type))
250
564
    }
251

            
252
274
    fn extension(&self, _key: &str) -> Option<Cow<str>> {
253
274
        None
254
274
    }
255

            
256
    fn allow_credentials_in_body(&self) -> bool {
257
        false
258
    }
259
}
260

            
261
impl<S> FromRequest<S> for RefreshTokenRequest
262
where
263
    Bytes: FromRequest<S>,
264
    S: Send + Sync,
265
{
266
    type Rejection = Response;
267

            
268
36
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
269
36
        let authorization = match parse_basic_auth(&req) {
270
            Err(e) => return Err(e.into_response()),
271
36
            Ok(auth) => auth,
272
        };
273
36
        let mut request = match Form::<RefreshTokenRequest>::from_request(req, state).await {
274
2
            Err(e) => {
275
2
                return Err(OAuth2Error::new_request(Some(e.to_string())).into_response());
276
            }
277
34
            Ok(request) => request.0,
278
34
        };
279
34
        request.authorization = authorization;
280
34
        Ok(request)
281
36
    }
282
}
283

            
284
impl OxideRefreshTokenRequest for RefreshTokenRequest {
285
34
    fn valid(&self) -> bool {
286
34
        true
287
34
    }
288

            
289
34
    fn refresh_token(&self) -> Option<Cow<str>> {
290
34
        Some(Cow::from(&self.refresh_token))
291
34
    }
292

            
293
46
    fn scope(&self) -> Option<Cow<str>> {
294
46
        match self.scope.as_ref() {
295
28
            None => None,
296
18
            Some(scope) => Some(Cow::from(scope)),
297
        }
298
46
    }
299

            
300
34
    fn grant_type(&self) -> Option<Cow<str>> {
301
34
        Some(Cow::from(&self.grant_type))
302
34
    }
303

            
304
30
    fn authorization(&self) -> Option<(Cow<str>, Cow<[u8]>)> {
305
30
        match self.authorization.as_ref() {
306
24
            None => None,
307
6
            Some(auth) => Some((Cow::from(auth.0.as_str()), Cow::from(auth.1.as_slice()))),
308
        }
309
30
    }
310

            
311
    fn extension(&self, _key: &str) -> Option<Cow<str>> {
312
        None
313
    }
314
}
315

            
316
328
fn parse_basic_auth(req: &Request) -> Result<Option<(String, Vec<u8>)>, OAuth2Error> {
317
328
    let mut auth_all = req.headers().get_all(header::AUTHORIZATION).iter();
318
328
    let auth = match auth_all.next() {
319
278
        None => return Ok(None),
320
50
        Some(auth) => match auth.to_str() {
321
            Err(e) => return Err(OAuth2Error::new_request(Some(e.to_string()))),
322
50
            Ok(auth) => auth,
323
50
        },
324
50
    };
325
50
    if auth_all.next() != None {
326
        return Err(OAuth2Error::new_request(Some(
327
            "invalid multiple Authorization header".to_string(),
328
        )));
329
50
    } else if !auth.starts_with("Basic ") || auth.len() < 7 {
330
        return Err(OAuth2Error::new_request(Some(
331
            "not a Basic header".to_string(),
332
        )));
333
50
    }
334
50
    let auth = match general_purpose::STANDARD.decode(&auth[6..]) {
335
        Err(e) => match general_purpose::STANDARD_NO_PAD.decode(&auth[6..]) {
336
            Err(_) => {
337
                return Err(OAuth2Error::new_request(Some(format!(
338
                    "invalid Basic content: {}",
339
                    e
340
                ))))
341
            }
342
            Ok(auth) => auth,
343
        },
344
50
        Ok(auth) => auth,
345
    };
346
418
    let mut split = auth.splitn(2, |&c| c == b':');
347
50
    let user = match split.next() {
348
        None => {
349
            return Err(OAuth2Error::new_request(Some(
350
                "invalid Basic content".to_string(),
351
            )))
352
        }
353
50
        Some(user) => user,
354
    };
355
50
    let pass = match split.next() {
356
        None => {
357
            return Err(OAuth2Error::new_request(Some(
358
                "invalid Basic content".to_string(),
359
            )))
360
        }
361
50
        Some(pass) => pass,
362
    };
363
50
    let user = match str::from_utf8(user) {
364
        Err(e) => {
365
            return Err(OAuth2Error::new_request(Some(format!(
366
                "invalid Basic content: {}",
367
                e
368
            ))))
369
        }
370
50
        Ok(user) => user,
371
50
    };
372
50
    Ok(Some((user.to_string(), pass.to_vec())))
373
328
}