1
use axum::{
2
    extract::{
3
        rejection::JsonRejection, FromRequest, FromRequestParts, Json as AxumJson,
4
        Path as AxumPath, Query as AxumQuery, Request,
5
    },
6
    http::{header, request::Parts},
7
    response::{IntoResponse, Response},
8
};
9
use bytes::{BufMut, BytesMut};
10
use serde::{de::DeserializeOwned, Serialize};
11

            
12
use crate::{constants::ContentType, err::ErrResp};
13

            
14
/// JSON Extractor / Response.
///
/// This is a customized version of [`axum::extract::Json`] that responds to errors with
/// [`ErrResp`] instead of axum's default rejection.
pub struct Json<T>(pub T);
18

            
19
/// Path Extractor / Response.
///
/// This is a customized version of [`axum::extract::Path`] that responds to errors with
/// [`ErrResp`] instead of axum's default rejection.
pub struct Path<T>(pub T);
23

            
24
/// Query Extractor / Response.
///
/// This is a customized version of [`axum::extract::Query`] that responds to errors with
/// [`ErrResp`] instead of axum's default rejection.
pub struct Query<T>(pub T);
28

            
29
impl<S, T> FromRequest<S> for Json<T>
30
where
31
    AxumJson<T>: FromRequest<S, Rejection = JsonRejection>,
32
    S: Send + Sync,
33
{
34
    type Rejection = ErrResp;
35

            
36
2
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
37
2
        match AxumJson::<T>::from_request(req, state).await {
38
1
            Err(e) => Err(ErrResp::ErrParam(Some(e.to_string()))),
39
1
            Ok(value) => Ok(Self(value.0)),
40
        }
41
2
    }
42
}
43

            
44
impl<T> IntoResponse for Json<T>
45
where
46
    T: Serialize,
47
{
48
1
    fn into_response(self) -> Response {
49
1
        // Use a small initial capacity of 128 bytes like serde_json::to_vec
50
1
        // https://docs.rs/serde_json/1.0.82/src/serde_json/ser.rs.html#2189
51
1
        let mut buf = BytesMut::with_capacity(128).writer();
52
1
        match serde_json::to_writer(&mut buf, &self.0) {
53
            Err(e) => ErrResp::ErrUnknown(Some(e.to_string())).into_response(),
54
1
            Ok(()) => (
55
1
                [(header::CONTENT_TYPE, ContentType::JSON)],
56
1
                buf.into_inner().freeze(),
57
1
            )
58
1
                .into_response(),
59
        }
60
1
    }
61
}
62

            
63
impl<T, S> FromRequestParts<S> for Path<T>
64
where
65
    T: DeserializeOwned + Send,
66
    S: Send + Sync,
67
{
68
    type Rejection = ErrResp;
69

            
70
2
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
71
2
        match AxumPath::from_request_parts(parts, state).await {
72
1
            Err(e) => Err(ErrResp::ErrParam(Some(e.to_string()))),
73
1
            Ok(value) => Ok(Self(value.0)),
74
        }
75
2
    }
76
}
77

            
78
impl<T, S> FromRequestParts<S> for Query<T>
79
where
80
    T: DeserializeOwned,
81
    S: Send + Sync,
82
{
83
    type Rejection = ErrResp;
84

            
85
2
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
86
2
        match AxumQuery::from_request_parts(parts, state).await {
87
1
            Err(e) => Err(ErrResp::ErrParam(Some(e.to_string()))),
88
1
            Ok(value) => Ok(Self(value.0)),
89
        }
90
2
    }
91
}
92

            
93
/// Parse Authorization header content. Returns `None` means no Authorization header.
94
3
pub fn parse_header_auth(req: &Request) -> Result<Option<String>, ErrResp> {
95
3
    let mut auth_all = req.headers().get_all(header::AUTHORIZATION).iter();
96
3
    let auth = match auth_all.next() {
97
1
        None => return Ok(None),
98
2
        Some(auth) => match auth.to_str() {
99
            Err(e) => return Err(ErrResp::ErrParam(Some(e.to_string()))),
100
2
            Ok(auth) => auth,
101
2
        },
102
2
    };
103
2
    if auth_all.next() != None {
104
1
        return Err(ErrResp::ErrParam(Some(
105
1
            "invalid multiple Authorization header".to_string(),
106
1
        )));
107
1
    }
108
1
    Ok(Some(auth.to_string()))
109
3
}