//! Provides the authentication middleware by sending the Bearer token to `sylvia-iot-auth`.
//!
//! Here is an example to wrap the auth middleware and how to get token information:
//!
//! ```rust
//! use axum::{
//!     extract::Request,
//!     http::{header, StatusCode},
//!     response::{IntoResponse, Response},
//!     routing, Extension, Router,
//! };
//! use sylvia_iot_sdk::middlewares::auth::{AuthService, GetTokenInfoData};
//!
//! fn new_service() -> Router {
//!     let auth_uri = "http://localhost:1080/auth/api/v1/auth/tokeninfo";
//!     Router::new()
//!         .route("/api", routing::get(api))
//!         .layer(AuthService::new(auth_uri.to_string()))
//! }
//!
//! async fn api(Extension(token_info): Extension<GetTokenInfoData>) -> impl IntoResponse {
//!     StatusCode::NO_CONTENT
//! }
//! ```
25

            
26
use axum::{
27
    extract::Request,
28
    http::header,
29
    response::{IntoResponse, Response},
30
};
31
use futures::future::BoxFuture;
32
use reqwest;
33
use serde::{self, Deserialize};
34
use std::{
35
    collections::HashMap,
36
    task::{Context, Poll},
37
};
38
use tower::{Layer, Service};
39

            
40
use crate::util::err::ErrResp;
41

            
42
/// Token information of an authenticated request.
///
/// [`AuthMiddleware`] inserts this value into the request extensions after the token-info API
/// accepted the Bearer token, so handlers can read it with `Extension<GetTokenInfoData>`.
#[derive(Clone)]
pub struct GetTokenInfoData {
    /// The access token (the credentials after the `Bearer` scheme).
    pub token: String,
    /// The user ID reported by the token-info API (`userId`).
    pub user_id: String,
    /// The account reported by the token-info API.
    pub account: String,
    /// The roles reported by the token-info API, as role name to flag.
    pub roles: HashMap<String, bool>,
    /// The name reported by the token-info API.
    pub name: String,
    /// The client ID reported by the token-info API (`clientId`).
    pub client_id: String,
    /// The scopes reported by the token-info API.
    pub scopes: Vec<String>,
}

impl std::fmt::Debug for GetTokenInfoData {
    /// Formats every field for diagnostics except the access token, which is redacted so that
    /// logging this value never leaks credentials.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GetTokenInfoData")
            .field("token", &"<redacted>")
            .field("user_id", &self.user_id)
            .field("account", &self.account)
            .field("roles", &self.roles)
            .field("name", &self.name)
            .field("client_id", &self.client_id)
            .field("scopes", &self.scopes)
            .finish()
    }
}
53

            
54
/// The [`Layer`] that wraps services with [`AuthMiddleware`].
///
/// Requests passing through the wrapped service must carry an `Authorization` header, which is
/// forwarded to the token-info API at `auth_uri` for validation.
#[derive(Clone, Debug)]
pub struct AuthService {
    /// The full URI of the token-info API of `sylvia-iot-auth`.
    auth_uri: String,
}
58

            
59
#[derive(Clone)]
60
pub struct AuthMiddleware<S> {
61
    client: reqwest::Client,
62
    auth_uri: String,
63
    service: S,
64
}
65

            
66
/// The user/client information of the token.
67
#[derive(Clone, Deserialize)]
68
struct GetTokenInfo {
69
    data: GetTokenInfoDataInner,
70
}
71

            
72
#[derive(Clone, Deserialize)]
73
struct GetTokenInfoDataInner {
74
    #[serde(rename = "userId")]
75
    user_id: String,
76
    account: String,
77
    roles: HashMap<String, bool>,
78
    name: String,
79
    #[serde(rename = "clientId")]
80
    client_id: String,
81
    scopes: Vec<String>,
82
}
83

            
84
impl AuthService {
85
4
    pub fn new(auth_uri: String) -> Self {
86
4
        AuthService { auth_uri }
87
4
    }
88
}
89

            
90
impl<S> Layer<S> for AuthService {
91
    type Service = AuthMiddleware<S>;
92

            
93
20
    fn layer(&self, inner: S) -> Self::Service {
94
20
        AuthMiddleware {
95
20
            client: reqwest::Client::new(),
96
20
            auth_uri: self.auth_uri.clone(),
97
20
            service: inner,
98
20
        }
99
20
    }
100
}
101

            
102
impl<S> Service<Request> for AuthMiddleware<S>
103
where
104
    S: Service<Request, Response = Response> + Clone + Send + 'static,
105
    S::Future: Send + 'static,
106
{
107
    type Response = S::Response;
108
    type Error = S::Error;
109
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
110

            
111
4
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
112
4
        self.service.poll_ready(cx)
113
4
    }
114

            
115
4
    fn call(&mut self, mut req: Request) -> Self::Future {
116
4
        let mut svc = self.service.clone();
117
4
        let client = self.client.clone();
118
4
        let auth_uri = self.auth_uri.clone();
119
4

            
120
4
        Box::pin(async move {
121
4
            let token = match parse_header_auth(&req) {
122
                Err(e) => return Ok(e.into_response()),
123
4
                Ok(token) => match token {
124
                    None => {
125
1
                        let e = ErrResp::ErrParam(Some("missing token".to_string()));
126
1
                        return Ok(e.into_response());
127
                    }
128
3
                    Some(token) => token,
129
                },
130
            };
131

            
132
3
            let token_req = match client
133
3
                .request(reqwest::Method::GET, auth_uri.as_str())
134
3
                .header(reqwest::header::AUTHORIZATION, token.as_str())
135
3
                .build()
136
            {
137
                Err(e) => {
138
                    let e = ErrResp::ErrRsc(Some(format!("request auth error: {}", e)));
139
                    return Ok(e.into_response());
140
                }
141
3
                Ok(req) => req,
142
            };
143
3
            let resp = match client.execute(token_req).await {
144
1
                Err(e) => {
145
1
                    let e = ErrResp::ErrIntMsg(Some(format!("auth error: {}", e)));
146
1
                    return Ok(e.into_response());
147
                }
148
2
                Ok(resp) => match resp.status() {
149
                    reqwest::StatusCode::UNAUTHORIZED => {
150
1
                        return Ok(ErrResp::ErrAuth(None).into_response())
151
                    }
152
1
                    reqwest::StatusCode::OK => resp,
153
                    _ => {
154
                        let e = ErrResp::ErrIntMsg(Some(format!(
155
                            "auth error with status code: {}",
156
                            resp.status()
157
                        )));
158
                        return Ok(e.into_response());
159
                    }
160
                },
161
            };
162
1
            let token_info = match resp.json::<GetTokenInfo>().await {
163
                Err(e) => {
164
                    let e = ErrResp::ErrIntMsg(Some(format!("read auth body error: {}", e)));
165
                    return Ok(e.into_response());
166
                }
167
1
                Ok(info) => info,
168
1
            };
169
1

            
170
1
            let mut split = token.split_whitespace();
171
1
            split.next(); // skip "Bearer".
172
1
            let token = match split.next() {
173
                None => {
174
                    let e = ErrResp::ErrUnknown(Some("parse token error".to_string()));
175
                    return Ok(e.into_response());
176
                }
177
1
                Some(token) => token.to_string(),
178
1
            };
179
1

            
180
1
            req.extensions_mut().insert(GetTokenInfoData {
181
1
                token,
182
1
                user_id: token_info.data.user_id,
183
1
                account: token_info.data.account,
184
1
                roles: token_info.data.roles,
185
1
                name: token_info.data.name,
186
1
                client_id: token_info.data.client_id,
187
1
                scopes: token_info.data.scopes,
188
1
            });
189

            
190
1
            let res = svc.call(req).await?;
191
1
            Ok(res)
192
4
        })
193
4
    }
194
}
195

            
196
/// Parse Authorization header content. Returns `None` means no Authorization header.
197
4
pub fn parse_header_auth(req: &Request) -> Result<Option<String>, ErrResp> {
198
4
    let mut auth_all = req.headers().get_all(header::AUTHORIZATION).iter();
199
4
    let auth = match auth_all.next() {
200
1
        None => return Ok(None),
201
3
        Some(auth) => match auth.to_str() {
202
            Err(e) => return Err(ErrResp::ErrParam(Some(e.to_string()))),
203
3
            Ok(auth) => auth,
204
3
        },
205
3
    };
206
3
    if auth_all.next() != None {
207
        return Err(ErrResp::ErrParam(Some(
208
            "invalid multiple Authorization header".to_string(),
209
        )));
210
3
    }
211
3
    Ok(Some(auth.to_string()))
212
4
}