1
//! Provides the authentication middleware by sending the Bearer token to [`sylvia-iot-auth`].
2

            
3
use std::{
4
    collections::HashMap,
5
    task::{Context, Poll},
6
};
7

            
8
use axum::{
9
    extract::Request,
10
    response::{IntoResponse, Response},
11
};
12
use futures::future::BoxFuture;
13
use reqwest;
14
use serde::{self, Deserialize};
15
use tower::{Layer, Service};
16

            
17
use sylvia_iot_corelib::{err::ErrResp, http as sylvia_http};
18

            
19
/// Token information attached to the request extensions by the authentication middleware
/// after the auth service has accepted the token.
#[derive(Clone)]
pub struct GetTokenInfoData {
    /// The access token.
    ///
    /// This is the second whitespace-separated part of the `Authorization` header value,
    /// i.e. the value without the leading scheme word.
    pub token: String,
    /// The user ID reported by the auth service (`userId` in the response body).
    pub user_id: String,
    /// The account reported by the auth service.
    pub account: String,
    /// Role flags reported by the auth service, keyed by name
    /// (presumably role name -> enabled; confirm against the auth service API).
    pub roles: HashMap<String, bool>,
    /// The name reported by the auth service.
    pub name: String,
    /// The client ID reported by the auth service (`clientId` in the response body).
    pub client_id: String,
    /// The scopes reported by the auth service.
    pub scopes: Vec<String>,
}
30

            
31
#[derive(Clone)]
32
pub struct AuthService {
33
    client: reqwest::Client,
34
    auth_uri: String,
35
}
36

            
37
#[derive(Clone)]
38
pub struct AuthMiddleware<S> {
39
    client: reqwest::Client,
40
    auth_uri: String,
41
    service: S,
42
}
43

            
44
/// The user/client information of the token.
45
#[derive(Clone, Deserialize)]
46
struct GetTokenInfo {
47
    data: GetTokenInfoDataInner,
48
}
49

            
50
#[derive(Clone, Deserialize)]
51
struct GetTokenInfoDataInner {
52
    #[serde(rename = "userId")]
53
    pub user_id: String,
54
    pub account: String,
55
    pub roles: HashMap<String, bool>,
56
    pub name: String,
57
    #[serde(rename = "clientId")]
58
    pub client_id: String,
59
    pub scopes: Vec<String>,
60
}
61

            
62
impl AuthService {
63
25276
    pub fn new(client: reqwest::Client, auth_uri: String) -> Self {
64
25276
        AuthService { client, auth_uri }
65
25276
    }
66
}
67

            
68
impl<S> Layer<S> for AuthService {
69
    type Service = AuthMiddleware<S>;
70

            
71
50600
    fn layer(&self, inner: S) -> Self::Service {
72
50600
        AuthMiddleware {
73
50600
            client: self.client.clone(),
74
50600
            auth_uri: self.auth_uri.clone(),
75
50600
            service: inner,
76
50600
        }
77
50600
    }
78
}
79

            
80
impl<S> Service<Request> for AuthMiddleware<S>
81
where
82
    S: Service<Request, Response = Response> + Clone + Send + 'static,
83
    S::Future: Send + 'static,
84
{
85
    type Response = S::Response;
86
    type Error = S::Error;
87
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
88

            
89
2544
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
90
2544
        self.service.poll_ready(cx)
91
2544
    }
92

            
93
2544
    fn call(&mut self, mut req: Request) -> Self::Future {
94
2544
        let mut svc = self.service.clone();
95
2544
        let client = self.client.clone();
96
2544
        let auth_uri = self.auth_uri.clone();
97

            
98
2544
        Box::pin(async move {
99
2544
            let token = match sylvia_http::parse_header_auth(&req) {
100
4
                Err(e) => return Ok(e.into_response()),
101
2540
                Ok(token) => match token {
102
                    None => {
103
4
                        let e = ErrResp::ErrParam(Some("missing token".to_string()));
104
4
                        return Ok(e.into_response());
105
                    }
106
2536
                    Some(token) => token,
107
                },
108
            };
109

            
110
2536
            let token_req = match client
111
2536
                .request(reqwest::Method::GET, auth_uri.as_str())
112
2536
                .header(reqwest::header::AUTHORIZATION, token.as_str())
113
2536
                .build()
114
            {
115
                Err(e) => {
116
                    let e = ErrResp::ErrRsc(Some(format!("request auth error: {}", e)));
117
                    return Ok(e.into_response());
118
                }
119
2536
                Ok(req) => req,
120
            };
121
2536
            let resp = match client.execute(token_req).await {
122
4
                Err(e) => {
123
4
                    let e = ErrResp::ErrIntMsg(Some(format!("auth error: {}", e)));
124
4
                    return Ok(e.into_response());
125
                }
126
2532
                Ok(resp) => match resp.status() {
127
                    reqwest::StatusCode::UNAUTHORIZED => {
128
44
                        return Ok(ErrResp::ErrAuth(None).into_response());
129
                    }
130
2488
                    reqwest::StatusCode::OK => resp,
131
                    _ => {
132
                        let e = ErrResp::ErrIntMsg(Some(format!(
133
                            "auth error with status code: {}",
134
                            resp.status()
135
                        )));
136
                        return Ok(e.into_response());
137
                    }
138
                },
139
            };
140
2488
            let token_info = match resp.json::<GetTokenInfo>().await {
141
                Err(e) => {
142
                    let e = ErrResp::ErrIntMsg(Some(format!("read auth body error: {}", e)));
143
                    return Ok(e.into_response());
144
                }
145
2488
                Ok(info) => info,
146
            };
147

            
148
2488
            let mut split = token.split_whitespace();
149
2488
            split.next(); // skip "Bearer".
150
2488
            let token = match split.next() {
151
                None => {
152
                    let e = ErrResp::ErrUnknown(Some("parse token error".to_string()));
153
                    return Ok(e.into_response());
154
                }
155
2488
                Some(token) => token.to_string(),
156
            };
157

            
158
2488
            req.extensions_mut().insert(GetTokenInfoData {
159
2488
                token,
160
2488
                user_id: token_info.data.user_id,
161
2488
                account: token_info.data.account,
162
2488
                roles: token_info.data.roles,
163
2488
                name: token_info.data.name,
164
2488
                client_id: token_info.data.client_id,
165
2488
                scopes: token_info.data.scopes,
166
2488
            });
167

            
168
2488
            let res = svc.call(req).await?;
169
2488
            Ok(res)
170
2544
        })
171
2544
    }
172
}