1
//! Provides the authentication middleware by sending the Bearer token to `sylvia-iot-auth`.
2
//!
3
//! Here is an example to wrap the auth middleware and how to get token information:
4
//!
5
//! ```rust
6
//! use axum::{
7
//!     extract::Request,
8
//!     http::{header, StatusCode},
9
//!     response::{IntoResponse, Response},
10
//!     routing, Extension, Router,
11
//! };
12
//! use reqwest;
13
//! use sylvia_iot_sdk::middlewares::auth::{AuthService, GetTokenInfoData};
14
//!
15
//! fn new_service() -> Router {
16
//!     let auth_uri = "http://localhost:1080/auth/api/v1/auth/tokeninfo";
17
//!     let client = reqwest::Client::new();
18
//!     Router::new()
19
//!         .route("/api", routing::get(api))
20
//!         .layer(AuthService::new(client, auth_uri.to_string()))
21
//! }
22
//!
23
//! async fn api(Extension(token_info): Extension<GetTokenInfoData>) -> impl IntoResponse {
24
//!     StatusCode::NO_CONTENT
25
//! }
26
//! ```
27

            
28
use axum::{
29
    extract::Request,
30
    http::header,
31
    response::{IntoResponse, Response},
32
};
33
use futures::future::BoxFuture;
34
use reqwest;
35
use serde::{self, Deserialize};
36
use std::{
37
    collections::HashMap,
38
    task::{Context, Poll},
39
};
40
use tower::{Layer, Service};
41

            
42
use crate::util::err::ErrResp;
43

            
44
/// Token information attached to the request extensions by [`AuthMiddleware`] after the
/// auth service accepted the request's `Authorization` header.
///
/// Handlers can extract it with `Extension<GetTokenInfoData>`.
#[derive(Clone)]
pub struct GetTokenInfoData {
    /// The access token: the credential part of the `Authorization` header, without the
    /// leading scheme word (e.g. without `Bearer`).
    pub token: String,
    /// User ID reported by the token-info endpoint (`userId` in its response).
    pub user_id: String,
    /// Account reported by the token-info endpoint.
    pub account: String,
    /// Role map reported by the token-info endpoint (role name -> flag).
    pub roles: HashMap<String, bool>,
    /// Name reported by the token-info endpoint (presumably the user's display name).
    pub name: String,
    /// Client ID reported by the token-info endpoint (`clientId` in its response).
    pub client_id: String,
    /// Scopes reported by the token-info endpoint.
    pub scopes: Vec<String>,
}
55

            
56
#[derive(Clone)]
57
pub struct AuthService {
58
    client: reqwest::Client,
59
    auth_uri: String,
60
}
61

            
62
#[derive(Clone)]
63
pub struct AuthMiddleware<S> {
64
    client: reqwest::Client,
65
    auth_uri: String,
66
    service: S,
67
}
68

            
69
/// The user/client information of the token.
70
#[derive(Clone, Deserialize)]
71
struct GetTokenInfo {
72
    data: GetTokenInfoDataInner,
73
}
74

            
75
#[derive(Clone, Deserialize)]
76
struct GetTokenInfoDataInner {
77
    #[serde(rename = "userId")]
78
    user_id: String,
79
    account: String,
80
    roles: HashMap<String, bool>,
81
    name: String,
82
    #[serde(rename = "clientId")]
83
    client_id: String,
84
    scopes: Vec<String>,
85
}
86

            
87
impl AuthService {
88
8
    pub fn new(client: reqwest::Client, auth_uri: String) -> Self {
89
8
        AuthService { client, auth_uri }
90
8
    }
91
}
92

            
93
impl<S> Layer<S> for AuthService {
94
    type Service = AuthMiddleware<S>;
95

            
96
40
    fn layer(&self, inner: S) -> Self::Service {
97
40
        AuthMiddleware {
98
40
            client: self.client.clone(),
99
40
            auth_uri: self.auth_uri.clone(),
100
40
            service: inner,
101
40
        }
102
40
    }
103
}
104

            
105
impl<S> Service<Request> for AuthMiddleware<S>
106
where
107
    S: Service<Request, Response = Response> + Clone + Send + 'static,
108
    S::Future: Send + 'static,
109
{
110
    type Response = S::Response;
111
    type Error = S::Error;
112
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
113

            
114
8
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
115
8
        self.service.poll_ready(cx)
116
8
    }
117

            
118
8
    fn call(&mut self, mut req: Request) -> Self::Future {
119
8
        let mut svc = self.service.clone();
120
8
        let client = self.client.clone();
121
8
        let auth_uri = self.auth_uri.clone();
122

            
123
8
        Box::pin(async move {
124
8
            let token = match parse_header_auth(&req) {
125
                Err(e) => return Ok(e.into_response()),
126
8
                Ok(token) => match token {
127
                    None => {
128
2
                        let e = ErrResp::ErrParam(Some("missing token".to_string()));
129
2
                        return Ok(e.into_response());
130
                    }
131
6
                    Some(token) => token,
132
                },
133
            };
134

            
135
6
            let token_req = match client
136
6
                .request(reqwest::Method::GET, auth_uri.as_str())
137
6
                .header(reqwest::header::AUTHORIZATION, token.as_str())
138
6
                .build()
139
            {
140
                Err(e) => {
141
                    let e = ErrResp::ErrRsc(Some(format!("request auth error: {}", e)));
142
                    return Ok(e.into_response());
143
                }
144
6
                Ok(req) => req,
145
            };
146
6
            let resp = match client.execute(token_req).await {
147
2
                Err(e) => {
148
2
                    let e = ErrResp::ErrIntMsg(Some(format!("auth error: {}", e)));
149
2
                    return Ok(e.into_response());
150
                }
151
4
                Ok(resp) => match resp.status() {
152
                    reqwest::StatusCode::UNAUTHORIZED => {
153
2
                        return Ok(ErrResp::ErrAuth(None).into_response());
154
                    }
155
2
                    reqwest::StatusCode::OK => resp,
156
                    _ => {
157
                        let e = ErrResp::ErrIntMsg(Some(format!(
158
                            "auth error with status code: {}",
159
                            resp.status()
160
                        )));
161
                        return Ok(e.into_response());
162
                    }
163
                },
164
            };
165
2
            let token_info = match resp.json::<GetTokenInfo>().await {
166
                Err(e) => {
167
                    let e = ErrResp::ErrIntMsg(Some(format!("read auth body error: {}", e)));
168
                    return Ok(e.into_response());
169
                }
170
2
                Ok(info) => info,
171
            };
172

            
173
2
            let mut split = token.split_whitespace();
174
2
            split.next(); // skip "Bearer".
175
2
            let token = match split.next() {
176
                None => {
177
                    let e = ErrResp::ErrUnknown(Some("parse token error".to_string()));
178
                    return Ok(e.into_response());
179
                }
180
2
                Some(token) => token.to_string(),
181
            };
182

            
183
2
            req.extensions_mut().insert(GetTokenInfoData {
184
2
                token,
185
2
                user_id: token_info.data.user_id,
186
2
                account: token_info.data.account,
187
2
                roles: token_info.data.roles,
188
2
                name: token_info.data.name,
189
2
                client_id: token_info.data.client_id,
190
2
                scopes: token_info.data.scopes,
191
2
            });
192

            
193
2
            let res = svc.call(req).await?;
194
2
            Ok(res)
195
8
        })
196
8
    }
197
}
198

            
199
/// Parse Authorization header content. Returns `None` means no Authorization header.
200
8
pub fn parse_header_auth(req: &Request) -> Result<Option<String>, ErrResp> {
201
8
    let mut auth_all = req.headers().get_all(header::AUTHORIZATION).iter();
202
8
    let auth = match auth_all.next() {
203
2
        None => return Ok(None),
204
6
        Some(auth) => match auth.to_str() {
205
            Err(e) => return Err(ErrResp::ErrParam(Some(e.to_string()))),
206
6
            Ok(auth) => auth,
207
        },
208
    };
209
6
    if auth_all.next() != None {
210
        return Err(ErrResp::ErrParam(Some(
211
            "invalid multiple Authorization header".to_string(),
212
        )));
213
6
    }
214
6
    Ok(Some(auth.to_string()))
215
8
}