1
//! Provides the authentication middleware by sending the Bearer token to [`sylvia-iot-auth`].
2

            
3
use std::{
4
    collections::{HashMap, HashSet},
5
    task::{Context, Poll},
6
};
7

            
8
use axum::{
9
    extract::Request,
10
    http::Method,
11
    response::{IntoResponse, Response},
12
};
13
use futures::future::BoxFuture;
14
use reqwest;
15
use serde::{self, Deserialize};
16
use tower::{Layer, Service};
17

            
18
use sylvia_iot_corelib::{err::ErrResp, http as sylvia_http};
19

            
20
/// Access requirements for one HTTP method, as `(roles, scopes)`.
///
/// A non-empty `roles` list requires the token to hold at least one of the listed roles (only
/// roles whose value is `true` count). A non-empty `scopes` list requires the token to carry at
/// least one of the listed scopes. An empty list disables that particular check.
pub type RoleScopeType = (Vec<&'static str>, Vec<String>);
/// Set-based form of [`RoleScopeType`], built once per wrapped service for fast lookups.
type RoleScopeInner = (HashSet<&'static str>, HashSet<String>);
22

            
23
/// Information about the validated token that [`AuthMiddleware`] inserts into the request
/// extensions before calling the inner service.
///
/// Handlers can retrieve it with `req.extensions().get::<GetTokenInfoData>()`.
#[derive(Clone)]
pub struct GetTokenInfoData {
    /// The access token: the second whitespace-separated part of the `Authorization` header
    /// value (the scheme prefix such as `Bearer` is dropped).
    pub token: String,
    /// The user ID reported by the auth service.
    pub user_id: String,
    /// The user account reported by the auth service.
    pub account: String,
    /// Role names mapped to flags; the middleware treats only `true` entries as held roles.
    pub roles: HashMap<String, bool>,
    /// The user name reported by the auth service.
    pub name: String,
    /// The client ID reported by the auth service.
    pub client_id: String,
    /// The scopes granted to the token.
    pub scopes: Vec<String>,
}
34

            
35
#[derive(Clone)]
36
pub struct AuthService {
37
    auth_uri: String,
38
    role_scopes: HashMap<Method, RoleScopeType>,
39
}
40

            
41
#[derive(Clone)]
42
pub struct AuthMiddleware<S> {
43
    client: reqwest::Client,
44
    auth_uri: String,
45
    role_scopes: HashMap<Method, RoleScopeInner>,
46
    service: S,
47
}
48

            
49
/// The user/client information of the token.
50
#[derive(Deserialize)]
51
struct GetTokenInfo {
52
    data: GetTokenInfoDataInner,
53
}
54

            
55
#[derive(Deserialize)]
56
struct GetTokenInfoDataInner {
57
    #[serde(rename = "userId")]
58
    user_id: String,
59
    account: String,
60
    roles: HashMap<String, bool>,
61
    name: String,
62
    #[serde(rename = "clientId")]
63
    client_id: String,
64
    scopes: Vec<String>,
65
}
66

            
67
impl AuthService {
68
240572
    pub fn new(auth_uri: String, role_scopes: HashMap<Method, RoleScopeType>) -> Self {
69
240572
        AuthService {
70
240572
            role_scopes,
71
240572
            auth_uri,
72
240572
        }
73
240572
    }
74
}
75

            
76
impl<S> Layer<S> for AuthService {
77
    type Service = AuthMiddleware<S>;
78

            
79
534668
    fn layer(&self, inner: S) -> Self::Service {
80
534668
        let mut role_scopes: HashMap<Method, RoleScopeInner> = HashMap::new();
81
828652
        for (k, (r, s)) in self.role_scopes.iter() {
82
828652
            role_scopes.insert(
83
828652
                k.clone(),
84
828652
                (
85
828652
                    r.iter().map(|&r| r).collect(),
86
828652
                    s.iter().map(|s| s.clone()).collect(),
87
828652
                ),
88
828652
            );
89
828652
        }
90

            
91
534668
        AuthMiddleware {
92
534668
            client: reqwest::Client::new(),
93
534668
            auth_uri: self.auth_uri.clone(),
94
534668
            role_scopes,
95
534668
            service: inner,
96
534668
        }
97
534668
    }
98
}
99

            
100
impl<S> Service<Request> for AuthMiddleware<S>
101
where
102
    S: Service<Request, Response = Response> + Clone + Send + 'static,
103
    S::Future: Send + 'static,
104
{
105
    type Response = S::Response;
106
    type Error = S::Error;
107
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
108

            
109
7106
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
110
7106
        self.service.poll_ready(cx)
111
7106
    }
112

            
113
7106
    fn call(&mut self, mut req: Request) -> Self::Future {
114
7106
        let mut svc = self.service.clone();
115
7106
        let client = self.client.clone();
116
7106
        let auth_uri = self.auth_uri.clone();
117
7106
        let role_scopes = self.role_scopes.clone();
118
7106

            
119
7106
        Box::pin(async move {
120
7106
            let token = match sylvia_http::parse_header_auth(&req) {
121
4
                Err(e) => return Ok(e.into_response()),
122
7102
                Ok(token) => match token {
123
                    None => {
124
4
                        let e = ErrResp::ErrParam(Some("missing token".to_string()));
125
4
                        return Ok(e.into_response());
126
                    }
127
7098
                    Some(token) => token,
128
                },
129
            };
130

            
131
7098
            let token_req = match client
132
7098
                .request(reqwest::Method::GET, auth_uri.as_str())
133
7098
                .header(reqwest::header::AUTHORIZATION, token.as_str())
134
7098
                .build()
135
            {
136
                Err(e) => {
137
                    let e = ErrResp::ErrRsc(Some(format!("request auth error: {}", e)));
138
                    return Ok(e.into_response());
139
                }
140
7098
                Ok(req) => req,
141
            };
142
7098
            let resp = match client.execute(token_req).await {
143
4
                Err(e) => {
144
4
                    let e = ErrResp::ErrIntMsg(Some(format!("auth error: {}", e)));
145
4
                    return Ok(e.into_response());
146
                }
147
7094
                Ok(resp) => match resp.status() {
148
                    reqwest::StatusCode::UNAUTHORIZED => {
149
268
                        return Ok(ErrResp::ErrAuth(None).into_response());
150
                    }
151
6826
                    reqwest::StatusCode::OK => resp,
152
                    _ => {
153
                        let e = ErrResp::ErrIntMsg(Some(format!(
154
                            "auth error with status code: {}",
155
                            resp.status()
156
                        )));
157
                        return Ok(e.into_response());
158
                    }
159
                },
160
            };
161
6826
            let token_info = match resp.json::<GetTokenInfo>().await {
162
                Err(e) => {
163
                    let e = ErrResp::ErrIntMsg(Some(format!("read auth body error: {}", e)));
164
                    return Ok(e.into_response());
165
                }
166
6826
                Ok(info) => info,
167
            };
168

            
169
6826
            if let Some((api_roles, api_scopes)) = role_scopes.get(req.method()) {
170
6822
                if api_roles.len() > 0 {
171
38
                    let roles: HashSet<&str> = token_info
172
38
                        .data
173
38
                        .roles
174
38
                        .iter()
175
38
                        .filter(|(_, v)| **v)
176
38
                        .map(|(k, _)| k.as_str())
177
38
                        .collect();
178
38
                    if api_roles.is_disjoint(&roles) {
179
16
                        let e = ErrResp::ErrPerm(Some("invalid role".to_string()));
180
16
                        return Ok(e.into_response());
181
22
                    }
182
6784
                }
183
6806
                if api_scopes.len() > 0 {
184
8
                    let api_scopes: HashSet<&str> = api_scopes.iter().map(|s| s.as_str()).collect();
185
8
                    let scopes: HashSet<&str> =
186
8
                        token_info.data.scopes.iter().map(|s| s.as_str()).collect();
187
8
                    if api_scopes.is_disjoint(&scopes) {
188
4
                        let e = ErrResp::ErrPerm(Some("invalid scope".to_string()));
189
4
                        return Ok(e.into_response());
190
4
                    }
191
6798
                }
192
4
            }
193

            
194
6806
            let mut split = token.split_whitespace();
195
6806
            split.next(); // skip "Bearer".
196
6806
            let token = match split.next() {
197
                None => {
198
                    let e = ErrResp::ErrUnknown(Some("parse token error".to_string()));
199
                    return Ok(e.into_response());
200
                }
201
6806
                Some(token) => token.to_string(),
202
6806
            };
203
6806

            
204
6806
            req.extensions_mut().insert(GetTokenInfoData {
205
6806
                token,
206
6806
                user_id: token_info.data.user_id,
207
6806
                account: token_info.data.account,
208
6806
                roles: token_info.data.roles,
209
6806
                name: token_info.data.name,
210
6806
                client_id: token_info.data.client_id,
211
6806
                scopes: token_info.data.scopes,
212
6806
            });
213

            
214
6806
            let res = svc.call(req).await?;
215
6806
            Ok(res)
216
7106
        })
217
7106
    }
218
}