1
//! Provides the authentication middleware by sending the Bearer token to [`sylvia-iot-auth`].
2

            
3
use std::{
4
    collections::HashMap,
5
    task::{Context, Poll},
6
};
7

            
8
use axum::{
9
    extract::Request,
10
    response::{IntoResponse, Response},
11
};
12
use futures::future::BoxFuture;
13
use reqwest;
14
use serde::{self, Deserialize};
15
use tower::{Layer, Service};
16

            
17
use sylvia_iot_corelib::{err::ErrResp, http as sylvia_http};
18

            
19
/// The user/client information of an authenticated token.
///
/// An instance is inserted into the request extensions by the authentication middleware after
/// the auth service accepted the token, so downstream handlers can read it from there.
#[derive(Clone)]
pub struct GetTokenInfoData {
    /// The access token itself, with the leading authorization scheme (e.g. `Bearer`) removed.
    pub token: String,
    /// User ID reported by the auth service (`userId` in its response).
    pub user_id: String,
    /// Account name reported by the auth service.
    pub account: String,
    /// Role flags reported by the auth service, keyed by role name.
    pub roles: HashMap<String, bool>,
    /// Display name reported by the auth service.
    pub name: String,
    /// Client ID reported by the auth service (`clientId` in its response).
    pub client_id: String,
    /// Scopes reported by the auth service for this token.
    pub scopes: Vec<String>,
}
30

            
31
/// Tower layer that wraps an inner service with [`AuthMiddleware`].
#[derive(Clone)]
pub struct AuthService {
    /// URI the middleware sends a `GET` request to in order to validate the request's token.
    auth_uri: String,
}
35

            
36
#[derive(Clone)]
37
pub struct AuthMiddleware<S> {
38
    client: reqwest::Client,
39
    auth_uri: String,
40
    service: S,
41
}
42

            
43
/// The user/client information of the token.
44
#[derive(Clone, Deserialize)]
45
struct GetTokenInfo {
46
    data: GetTokenInfoDataInner,
47
}
48

            
49
#[derive(Clone, Deserialize)]
50
struct GetTokenInfoDataInner {
51
    #[serde(rename = "userId")]
52
    pub user_id: String,
53
    pub account: String,
54
    pub roles: HashMap<String, bool>,
55
    pub name: String,
56
    #[serde(rename = "clientId")]
57
    pub client_id: String,
58
    pub scopes: Vec<String>,
59
}
60

            
61
impl AuthService {
62
12638
    pub fn new(auth_uri: String) -> Self {
63
12638
        AuthService { auth_uri }
64
12638
    }
65
}
66

            
67
impl<S> Layer<S> for AuthService {
68
    type Service = AuthMiddleware<S>;
69

            
70
25300
    fn layer(&self, inner: S) -> Self::Service {
71
25300
        AuthMiddleware {
72
25300
            client: reqwest::Client::new(),
73
25300
            auth_uri: self.auth_uri.clone(),
74
25300
            service: inner,
75
25300
        }
76
25300
    }
77
}
78

            
79
impl<S> Service<Request> for AuthMiddleware<S>
80
where
81
    S: Service<Request, Response = Response> + Clone + Send + 'static,
82
    S::Future: Send + 'static,
83
{
84
    type Response = S::Response;
85
    type Error = S::Error;
86
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
87

            
88
1272
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
89
1272
        self.service.poll_ready(cx)
90
1272
    }
91

            
92
1272
    fn call(&mut self, mut req: Request) -> Self::Future {
93
1272
        let mut svc = self.service.clone();
94
1272
        let client = self.client.clone();
95
1272
        let auth_uri = self.auth_uri.clone();
96
1272

            
97
1272
        Box::pin(async move {
98
1272
            let token = match sylvia_http::parse_header_auth(&req) {
99
2
                Err(e) => return Ok(e.into_response()),
100
1270
                Ok(token) => match token {
101
                    None => {
102
2
                        let e = ErrResp::ErrParam(Some("missing token".to_string()));
103
2
                        return Ok(e.into_response());
104
                    }
105
1268
                    Some(token) => token,
106
                },
107
            };
108

            
109
1268
            let token_req = match client
110
1268
                .request(reqwest::Method::GET, auth_uri.as_str())
111
1268
                .header(reqwest::header::AUTHORIZATION, token.as_str())
112
1268
                .build()
113
            {
114
                Err(e) => {
115
                    let e = ErrResp::ErrRsc(Some(format!("request auth error: {}", e)));
116
                    return Ok(e.into_response());
117
                }
118
1268
                Ok(req) => req,
119
            };
120
1268
            let resp = match client.execute(token_req).await {
121
2
                Err(e) => {
122
2
                    let e = ErrResp::ErrIntMsg(Some(format!("auth error: {}", e)));
123
2
                    return Ok(e.into_response());
124
                }
125
1266
                Ok(resp) => match resp.status() {
126
                    reqwest::StatusCode::UNAUTHORIZED => {
127
22
                        return Ok(ErrResp::ErrAuth(None).into_response())
128
                    }
129
1244
                    reqwest::StatusCode::OK => resp,
130
                    _ => {
131
                        let e = ErrResp::ErrIntMsg(Some(format!(
132
                            "auth error with status code: {}",
133
                            resp.status()
134
                        )));
135
                        return Ok(e.into_response());
136
                    }
137
                },
138
            };
139
1244
            let token_info = match resp.json::<GetTokenInfo>().await {
140
                Err(e) => {
141
                    let e = ErrResp::ErrIntMsg(Some(format!("read auth body error: {}", e)));
142
                    return Ok(e.into_response());
143
                }
144
1244
                Ok(info) => info,
145
1244
            };
146
1244

            
147
1244
            let mut split = token.split_whitespace();
148
1244
            split.next(); // skip "Bearer".
149
1244
            let token = match split.next() {
150
                None => {
151
                    let e = ErrResp::ErrUnknown(Some("parse token error".to_string()));
152
                    return Ok(e.into_response());
153
                }
154
1244
                Some(token) => token.to_string(),
155
1244
            };
156
1244

            
157
1244
            req.extensions_mut().insert(GetTokenInfoData {
158
1244
                token,
159
1244
                user_id: token_info.data.user_id,
160
1244
                account: token_info.data.account,
161
1244
                roles: token_info.data.roles,
162
1244
                name: token_info.data.name,
163
1244
                client_id: token_info.data.client_id,
164
1244
                scopes: token_info.data.scopes,
165
1244
            });
166

            
167
1244
            let res = svc.call(req).await?;
168
1244
            Ok(res)
169
1272
        })
170
1272
    }
171
}