1
//! Provides the authentication middleware by sending the Bearer token to [`sylvia-iot-auth`].
2

            
3
use std::{
4
    collections::{HashMap, HashSet},
5
    task::{Context, Poll},
6
};
7

            
8
use axum::{
9
    extract::Request,
10
    http::Method,
11
    response::{IntoResponse, Response},
12
};
13
use futures::future::BoxFuture;
14
use reqwest;
15
use serde::{self, Deserialize};
16
use tower::{Layer, Service};
17

            
18
use sylvia_iot_corelib::{err::ErrResp, http as sylvia_http};
19

            
20
/// Role/scope requirement for one HTTP method, as supplied to [`AuthService::new`]:
/// `(roles, scopes)`.
///
/// An empty list places no restriction on that dimension. Otherwise the token must hold at
/// least one of the listed roles (with a `true` value), respectively at least one of the
/// listed scopes.
pub type RoleScopeType = (Vec<&'static str>, Vec<String>);

/// Set-based form of [`RoleScopeType`] kept by [`AuthMiddleware`] for per-request lookups.
type RoleScopeInner = (HashSet<&'static str>, HashSet<String>);
22

            
23
/// Token information that [`AuthMiddleware`] inserts into the request extensions after a
/// successful authentication, for use by downstream handlers.
#[derive(Clone)]
pub struct GetTokenInfoData {
    /// The access token, i.e. the part of the authorization header after the `Bearer` word.
    pub token: String,
    /// The user ID reported by the auth service.
    pub user_id: String,
    /// The user account reported by the auth service.
    pub account: String,
    /// The user's roles. The middleware's role check only counts entries whose value is `true`.
    pub roles: HashMap<String, bool>,
    /// The user name reported by the auth service.
    pub name: String,
    /// The client ID reported by the auth service.
    pub client_id: String,
    /// The scopes of the token reported by the auth service.
    pub scopes: Vec<String>,
}
34

            
35
#[derive(Clone)]
36
pub struct AuthService {
37
    auth_uri: String,
38
    role_scopes: HashMap<Method, RoleScopeType>,
39
}
40

            
41
#[derive(Clone)]
42
pub struct AuthMiddleware<S> {
43
    client: reqwest::Client,
44
    auth_uri: String,
45
    role_scopes: HashMap<Method, RoleScopeInner>,
46
    service: S,
47
}
48

            
49
/// The user/client information of the token.
50
6826
#[derive(Deserialize)]
51
struct GetTokenInfo {
52
    data: GetTokenInfoDataInner,
53
}
54

            
55
23891
#[derive(Deserialize)]
56
struct GetTokenInfoDataInner {
57
    #[serde(rename = "userId")]
58
    user_id: String,
59
    account: String,
60
    roles: HashMap<String, bool>,
61
    name: String,
62
    #[serde(rename = "clientId")]
63
    client_id: String,
64
    scopes: Vec<String>,
65
}
66

            
67
impl AuthService {
68
120286
    pub fn new(auth_uri: String, role_scopes: HashMap<Method, RoleScopeType>) -> Self {
69
120286
        AuthService {
70
120286
            role_scopes,
71
120286
            auth_uri,
72
120286
        }
73
120286
    }
74
}
75

            
76
impl<S> Layer<S> for AuthService {
77
    type Service = AuthMiddleware<S>;
78

            
79
267334
    fn layer(&self, inner: S) -> Self::Service {
80
267334
        let mut role_scopes: HashMap<Method, RoleScopeInner> = HashMap::new();
81
414326
        for (k, (r, s)) in self.role_scopes.iter() {
82
414326
            role_scopes.insert(
83
414326
                k.clone(),
84
414326
                (
85
414326
                    r.iter().map(|&r| r).collect(),
86
414326
                    s.iter().map(|s| s.clone()).collect(),
87
414326
                ),
88
414326
            );
89
414326
        }
90

            
91
267334
        AuthMiddleware {
92
267334
            client: reqwest::Client::new(),
93
267334
            auth_uri: self.auth_uri.clone(),
94
267334
            role_scopes,
95
267334
            service: inner,
96
267334
        }
97
267334
    }
98
}
99

            
100
impl<S> Service<Request> for AuthMiddleware<S>
101
where
102
    S: Service<Request, Response = Response> + Clone + Send + 'static,
103
    S::Future: Send + 'static,
104
{
105
    type Response = S::Response;
106
    type Error = S::Error;
107
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
108

            
109
3553
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
110
3553
        self.service.poll_ready(cx)
111
3553
    }
112

            
113
3553
    fn call(&mut self, mut req: Request) -> Self::Future {
114
3553
        let mut svc = self.service.clone();
115
3553
        let client = self.client.clone();
116
3553
        let auth_uri = self.auth_uri.clone();
117
3553
        let role_scopes = self.role_scopes.clone();
118
3553

            
119
3553
        Box::pin(async move {
120
3553
            let token = match sylvia_http::parse_header_auth(&req) {
121
2
                Err(e) => return Ok(e.into_response()),
122
3551
                Ok(token) => match token {
123
                    None => {
124
2
                        let e = ErrResp::ErrParam(Some("missing token".to_string()));
125
2
                        return Ok(e.into_response());
126
                    }
127
3549
                    Some(token) => token,
128
                },
129
            };
130

            
131
3549
            let token_req = match client
132
3549
                .request(reqwest::Method::GET, auth_uri.as_str())
133
3549
                .header(reqwest::header::AUTHORIZATION, token.as_str())
134
3549
                .build()
135
            {
136
                Err(e) => {
137
                    let e = ErrResp::ErrRsc(Some(format!("request auth error: {}", e)));
138
                    return Ok(e.into_response());
139
                }
140
3549
                Ok(req) => req,
141
            };
142
16991
            let resp = match client.execute(token_req).await {
143
2
                Err(e) => {
144
2
                    let e = ErrResp::ErrIntMsg(Some(format!("auth error: {}", e)));
145
2
                    return Ok(e.into_response());
146
                }
147
3547
                Ok(resp) => match resp.status() {
148
                    reqwest::StatusCode::UNAUTHORIZED => {
149
134
                        return Ok(ErrResp::ErrAuth(None).into_response())
150
                    }
151
3413
                    reqwest::StatusCode::OK => resp,
152
                    _ => {
153
                        let e = ErrResp::ErrIntMsg(Some(format!(
154
                            "auth error with status code: {}",
155
                            resp.status()
156
                        )));
157
                        return Ok(e.into_response());
158
                    }
159
                },
160
            };
161
3413
            let token_info = match resp.json::<GetTokenInfo>().await {
162
                Err(e) => {
163
                    let e = ErrResp::ErrIntMsg(Some(format!("read auth body error: {}", e)));
164
                    return Ok(e.into_response());
165
                }
166
3413
                Ok(info) => info,
167
            };
168

            
169
3413
            if let Some((api_roles, api_scopes)) = role_scopes.get(req.method()) {
170
3411
                if api_roles.len() > 0 {
171
19
                    let roles: HashSet<&str> = token_info
172
19
                        .data
173
19
                        .roles
174
19
                        .iter()
175
19
                        .filter(|(_, &v)| v)
176
19
                        .map(|(k, _)| k.as_str())
177
19
                        .collect();
178
19
                    if api_roles.is_disjoint(&roles) {
179
8
                        let e = ErrResp::ErrPerm(Some("invalid role".to_string()));
180
8
                        return Ok(e.into_response());
181
11
                    }
182
3392
                }
183
3403
                if api_scopes.len() > 0 {
184
4
                    let api_scopes: HashSet<&str> = api_scopes.iter().map(|s| s.as_str()).collect();
185
4
                    let scopes: HashSet<&str> =
186
4
                        token_info.data.scopes.iter().map(|s| s.as_str()).collect();
187
4
                    if api_scopes.is_disjoint(&scopes) {
188
2
                        let e = ErrResp::ErrPerm(Some("invalid scope".to_string()));
189
2
                        return Ok(e.into_response());
190
2
                    }
191
3399
                }
192
2
            }
193

            
194
3403
            let mut split = token.split_whitespace();
195
3403
            split.next(); // skip "Bearer".
196
3403
            let token = match split.next() {
197
                None => {
198
                    let e = ErrResp::ErrUnknown(Some("parse token error".to_string()));
199
                    return Ok(e.into_response());
200
                }
201
3403
                Some(token) => token.to_string(),
202
3403
            };
203
3403

            
204
3403
            req.extensions_mut().insert(GetTokenInfoData {
205
3403
                token,
206
3403
                user_id: token_info.data.user_id,
207
3403
                account: token_info.data.account,
208
3403
                roles: token_info.data.roles,
209
3403
                name: token_info.data.name,
210
3403
                client_id: token_info.data.client_id,
211
3403
                scopes: token_info.data.scopes,
212
3403
            });
213

            
214
11731
            let res = svc.call(req).await?;
215
3403
            Ok(res)
216
3553
        })
217
3553
    }
218
}