1
use async_trait::async_trait;
2
use axum::{
3
    extract::{
4
        rejection::JsonRejection, FromRequest, FromRequestParts, Json as AxumJson,
5
        Path as AxumPath, Query as AxumQuery, Request,
6
    },
7
    http::{header, request::Parts},
8
    response::{IntoResponse, Response},
9
};
10
use bytes::{BufMut, BytesMut};
11
use serde::{de::DeserializeOwned, Serialize};
12

            
13
use crate::{constants::ContentType, err::ErrResp};
14

            
15
/// JSON Extractor / Response.
///
/// This is the customized [`axum::extract::Json`] version that responds with [`ErrResp`] on
/// errors: extraction failures become [`ErrResp::ErrParam`] and serialization failures become
/// [`ErrResp::ErrUnknown`].
#[derive(Debug, Clone, Copy, Default)]
pub struct Json<T>(pub T);
19

            
20
/// Path Extractor / Response.
///
/// This is the customized [`axum::extract::Path`] version that responds with [`ErrResp`]
/// ([`ErrResp::ErrParam`]) when path parameters cannot be extracted.
#[derive(Debug, Clone, Copy, Default)]
pub struct Path<T>(pub T);
24

            
25
/// Query Extractor / Response.
///
/// This is the customized [`axum::extract::Query`] version that responds with [`ErrResp`]
/// ([`ErrResp::ErrParam`]) when the query string cannot be extracted.
#[derive(Debug, Clone, Copy, Default)]
pub struct Query<T>(pub T);
29

            
30
#[async_trait]
31
impl<S, T> FromRequest<S> for Json<T>
32
where
33
    AxumJson<T>: FromRequest<S, Rejection = JsonRejection>,
34
    S: Send + Sync,
35
{
36
    type Rejection = ErrResp;
37

            
38
2
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
39
2
        match AxumJson::<T>::from_request(req, state).await {
40
1
            Err(e) => Err(ErrResp::ErrParam(Some(e.to_string()))),
41
1
            Ok(value) => Ok(Self(value.0)),
42
        }
43
4
    }
44
}
45

            
46
impl<T> IntoResponse for Json<T>
47
where
48
    T: Serialize,
49
{
50
1
    fn into_response(self) -> Response {
51
1
        // Use a small initial capacity of 128 bytes like serde_json::to_vec
52
1
        // https://docs.rs/serde_json/1.0.82/src/serde_json/ser.rs.html#2189
53
1
        let mut buf = BytesMut::with_capacity(128).writer();
54
1
        match serde_json::to_writer(&mut buf, &self.0) {
55
            Err(e) => ErrResp::ErrUnknown(Some(e.to_string())).into_response(),
56
1
            Ok(()) => (
57
1
                [(header::CONTENT_TYPE, ContentType::JSON)],
58
1
                buf.into_inner().freeze(),
59
1
            )
60
1
                .into_response(),
61
        }
62
1
    }
63
}
64

            
65
#[async_trait]
66
impl<T, S> FromRequestParts<S> for Path<T>
67
where
68
    T: DeserializeOwned + Send,
69
    S: Send + Sync,
70
{
71
    type Rejection = ErrResp;
72

            
73
2
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
74
2
        match AxumPath::from_request_parts(parts, state).await {
75
1
            Err(e) => Err(ErrResp::ErrParam(Some(e.to_string()))),
76
1
            Ok(value) => Ok(Self(value.0)),
77
        }
78
4
    }
79
}
80

            
81
#[async_trait]
82
impl<T, S> FromRequestParts<S> for Query<T>
83
where
84
    T: DeserializeOwned,
85
    S: Send + Sync,
86
{
87
    type Rejection = ErrResp;
88

            
89
2
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
90
2
        match AxumQuery::from_request_parts(parts, state).await {
91
1
            Err(e) => Err(ErrResp::ErrParam(Some(e.to_string()))),
92
1
            Ok(value) => Ok(Self(value.0)),
93
        }
94
4
    }
95
}
96

            
97
/// Parse Authorization header content. Returns `None` means no Authorization header.
98
3
pub fn parse_header_auth(req: &Request) -> Result<Option<String>, ErrResp> {
99
3
    let mut auth_all = req.headers().get_all(header::AUTHORIZATION).iter();
100
3
    let auth = match auth_all.next() {
101
1
        None => return Ok(None),
102
2
        Some(auth) => match auth.to_str() {
103
            Err(e) => return Err(ErrResp::ErrParam(Some(e.to_string()))),
104
2
            Ok(auth) => auth,
105
2
        },
106
2
    };
107
2
    if auth_all.next() != None {
108
1
        return Err(ErrResp::ErrParam(Some(
109
1
            "invalid multiple Authorization header".to_string(),
110
1
        )));
111
1
    }
112
1
    Ok(Some(auth.to_string()))
113
3
}