1
//! Provides the authentication middleware, which validates each request by sending its Bearer
//! token to [`sylvia-iot-auth`].
2

            
3
use std::{
4
    collections::{HashMap, HashSet},
5
    task::{Context, Poll},
6
};
7

            
8
use axum::{
9
    extract::Request,
10
    http::Method,
11
    response::{IntoResponse, Response},
12
};
13
use futures::future::BoxFuture;
14
use reqwest;
15
use serde::{self, Deserialize};
16
use tower::{Layer, Service};
17

            
18
use sylvia_iot_corelib::{err::ErrResp, http as sylvia_http};
19

            
20
/// Access requirements for one HTTP method: `(roles, scopes)`.
///
/// This is the value type of the per-method map passed to [`AuthService::new`]. The middleware
/// skips a check whose list is empty; otherwise the token must carry at least one of the listed
/// roles (first element) and at least one of the listed scopes (second element).
pub type RoleScopeType = (Vec<&'static str>, Vec<String>);
21
/// Hash-set form of [`RoleScopeType`] stored in [`AuthMiddleware`]: `(roles, scopes)`.
type RoleScopeInner = (HashSet<&'static str>, HashSet<String>);
22

            
23
/// Token information that the middleware inserts into the request extensions once the token has
/// been accepted, so that later handlers can read it.
#[derive(Clone)]
pub struct GetTokenInfoData {
    /// The access token (the part of the `Authorization` header value after the scheme).
    pub token: String,
    /// The `userId` value reported by the auth service for this token.
    pub user_id: String,
    /// The `account` value reported by the auth service for this token.
    pub account: String,
    /// Role flags reported by the auth service; the middleware only treats a role as held when
    /// its value is `true`.
    pub roles: HashMap<String, bool>,
    /// The `name` value reported by the auth service for this token.
    pub name: String,
    /// The `clientId` value reported by the auth service for this token.
    pub client_id: String,
    /// The scopes reported by the auth service for this token.
    pub scopes: Vec<String>,
}
34

            
35
#[derive(Clone)]
36
pub struct AuthService {
37
    client: reqwest::Client,
38
    auth_uri: String,
39
    role_scopes: HashMap<Method, RoleScopeType>,
40
}
41

            
42
#[derive(Clone)]
43
pub struct AuthMiddleware<S> {
44
    client: reqwest::Client,
45
    auth_uri: String,
46
    role_scopes: HashMap<Method, RoleScopeInner>,
47
    service: S,
48
}
49

            
50
/// The user/client information of the token.
51
#[derive(Deserialize)]
52
struct GetTokenInfo {
53
    data: GetTokenInfoDataInner,
54
}
55

            
56
#[derive(Deserialize)]
57
struct GetTokenInfoDataInner {
58
    #[serde(rename = "userId")]
59
    user_id: String,
60
    account: String,
61
    roles: HashMap<String, bool>,
62
    name: String,
63
    #[serde(rename = "clientId")]
64
    client_id: String,
65
    scopes: Vec<String>,
66
}
67

            
68
impl AuthService {
69
240572
    pub fn new(
70
240572
        client: reqwest::Client,
71
240572
        auth_uri: String,
72
240572
        role_scopes: HashMap<Method, RoleScopeType>,
73
240572
    ) -> Self {
74
240572
        AuthService {
75
240572
            client,
76
240572
            auth_uri,
77
240572
            role_scopes,
78
240572
        }
79
240572
    }
80
}
81

            
82
impl<S> Layer<S> for AuthService {
83
    type Service = AuthMiddleware<S>;
84

            
85
534668
    fn layer(&self, inner: S) -> Self::Service {
86
534668
        let mut role_scopes: HashMap<Method, RoleScopeInner> = HashMap::new();
87
828652
        for (k, (r, s)) in self.role_scopes.iter() {
88
828652
            role_scopes.insert(
89
828652
                k.clone(),
90
                (
91
828652
                    r.iter().map(|&r| r).collect(),
92
828652
                    s.iter().map(|s| s.clone()).collect(),
93
                ),
94
            );
95
        }
96

            
97
534668
        AuthMiddleware {
98
534668
            client: self.client.clone(),
99
534668
            auth_uri: self.auth_uri.clone(),
100
534668
            role_scopes,
101
534668
            service: inner,
102
534668
        }
103
534668
    }
104
}
105

            
106
impl<S> Service<Request> for AuthMiddleware<S>
107
where
108
    S: Service<Request, Response = Response> + Clone + Send + 'static,
109
    S::Future: Send + 'static,
110
{
111
    type Response = S::Response;
112
    type Error = S::Error;
113
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
114

            
115
7106
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
116
7106
        self.service.poll_ready(cx)
117
7106
    }
118

            
119
7106
    fn call(&mut self, mut req: Request) -> Self::Future {
120
7106
        let mut svc = self.service.clone();
121
7106
        let client = self.client.clone();
122
7106
        let auth_uri = self.auth_uri.clone();
123
7106
        let role_scopes = self.role_scopes.clone();
124

            
125
7106
        Box::pin(async move {
126
7106
            let token = match sylvia_http::parse_header_auth(&req) {
127
4
                Err(e) => return Ok(e.into_response()),
128
7102
                Ok(token) => match token {
129
                    None => {
130
4
                        let e = ErrResp::ErrParam(Some("missing token".to_string()));
131
4
                        return Ok(e.into_response());
132
                    }
133
7098
                    Some(token) => token,
134
                },
135
            };
136

            
137
7098
            let token_req = match client
138
7098
                .request(reqwest::Method::GET, auth_uri.as_str())
139
7098
                .header(reqwest::header::AUTHORIZATION, token.as_str())
140
7098
                .build()
141
            {
142
                Err(e) => {
143
                    let e = ErrResp::ErrRsc(Some(format!("request auth error: {}", e)));
144
                    return Ok(e.into_response());
145
                }
146
7098
                Ok(req) => req,
147
            };
148
7098
            let resp = match client.execute(token_req).await {
149
4
                Err(e) => {
150
4
                    let e = ErrResp::ErrIntMsg(Some(format!("auth error: {}", e)));
151
4
                    return Ok(e.into_response());
152
                }
153
7094
                Ok(resp) => match resp.status() {
154
                    reqwest::StatusCode::UNAUTHORIZED => {
155
268
                        return Ok(ErrResp::ErrAuth(None).into_response());
156
                    }
157
6826
                    reqwest::StatusCode::OK => resp,
158
                    _ => {
159
                        let e = ErrResp::ErrIntMsg(Some(format!(
160
                            "auth error with status code: {}",
161
                            resp.status()
162
                        )));
163
                        return Ok(e.into_response());
164
                    }
165
                },
166
            };
167
6826
            let token_info = match resp.json::<GetTokenInfo>().await {
168
                Err(e) => {
169
                    let e = ErrResp::ErrIntMsg(Some(format!("read auth body error: {}", e)));
170
                    return Ok(e.into_response());
171
                }
172
6826
                Ok(info) => info,
173
            };
174

            
175
6826
            if let Some((api_roles, api_scopes)) = role_scopes.get(req.method()) {
176
6822
                if api_roles.len() > 0 {
177
38
                    let roles: HashSet<&str> = token_info
178
38
                        .data
179
38
                        .roles
180
38
                        .iter()
181
38
                        .filter(|(_, v)| **v)
182
38
                        .map(|(k, _)| k.as_str())
183
38
                        .collect();
184
38
                    if api_roles.is_disjoint(&roles) {
185
16
                        let e = ErrResp::ErrPerm(Some("invalid role".to_string()));
186
16
                        return Ok(e.into_response());
187
22
                    }
188
6784
                }
189
6806
                if api_scopes.len() > 0 {
190
8
                    let api_scopes: HashSet<&str> = api_scopes.iter().map(|s| s.as_str()).collect();
191
8
                    let scopes: HashSet<&str> =
192
8
                        token_info.data.scopes.iter().map(|s| s.as_str()).collect();
193
8
                    if api_scopes.is_disjoint(&scopes) {
194
4
                        let e = ErrResp::ErrPerm(Some("invalid scope".to_string()));
195
4
                        return Ok(e.into_response());
196
4
                    }
197
6798
                }
198
4
            }
199

            
200
6806
            let mut split = token.split_whitespace();
201
6806
            split.next(); // skip "Bearer".
202
6806
            let token = match split.next() {
203
                None => {
204
                    let e = ErrResp::ErrUnknown(Some("parse token error".to_string()));
205
                    return Ok(e.into_response());
206
                }
207
6806
                Some(token) => token.to_string(),
208
            };
209

            
210
6806
            req.extensions_mut().insert(GetTokenInfoData {
211
6806
                token,
212
6806
                user_id: token_info.data.user_id,
213
6806
                account: token_info.data.account,
214
6806
                roles: token_info.data.roles,
215
6806
                name: token_info.data.name,
216
6806
                client_id: token_info.data.client_id,
217
6806
                scopes: token_info.data.scopes,
218
6806
            });
219

            
220
6806
            let res = svc.call(req).await?;
221
6806
            Ok(res)
222
7106
        })
223
7106
    }
224
}