1
//! Provides the authentication middleware by sending the Bearer token to [`sylvia-iot-auth`].
2

            
3
use std::{
4
    cell::RefCell,
5
    collections::HashMap,
6
    rc::Rc,
7
    task::{Context, Poll},
8
};
9

            
10
use actix_service::{Service, Transform};
11
use actix_web::{
12
    body::BoxBody,
13
    dev::{ServiceRequest, ServiceResponse},
14
    Error, HttpMessage,
15
};
16
use futures::future::{self, LocalBoxFuture, Ready};
17
use reqwest;
18
use serde::{self, Deserialize};
19

            
20
use sylvia_iot_corelib::{err::ErrResp, http as sylvia_http};
21

            
22
/// The information contains [`GetTokenInfoData`] and access token.
23
#[derive(Clone)]
24
pub struct FullTokenInfo {
25
    pub token: String,
26
    pub info: GetTokenInfoData,
27
}
28

            
29
/// The user/client information of the token.
30
3732
#[derive(Clone, Deserialize)]
31
pub struct GetTokenInfo {
32
    pub data: GetTokenInfoData,
33
}
34

            
35
16172
#[derive(Clone, Deserialize)]
36
pub struct GetTokenInfoData {
37
    #[serde(rename = "userId")]
38
    pub user_id: String,
39
    pub account: String,
40
    pub roles: HashMap<String, bool>,
41
    pub name: String,
42
    #[serde(rename = "clientId")]
43
    pub client_id: String,
44
    pub scopes: Vec<String>,
45
}
46

            
47
/// Middleware factory: produces an [`AuthMiddleware`] for every service it wraps.
pub struct AuthService {
    /// URL that receives a `GET` request carrying the caller's `Authorization` value
    /// (presumably the auth service's token-info endpoint).
    auth_uri: String,
}
50

            
51
pub struct AuthMiddleware<S> {
52
    client: reqwest::Client,
53
    auth_uri: String,
54
    service: Rc<RefCell<S>>,
55
}
56

            
57
impl AuthService {
58
12638
    pub fn new(auth_uri: String) -> Self {
59
12638
        AuthService { auth_uri }
60
12638
    }
61
}
62

            
63
impl<S> Transform<S, ServiceRequest> for AuthService
64
where
65
    S: Service<ServiceRequest, Response = ServiceResponse<BoxBody>, Error = Error> + 'static,
66
    S::Future: 'static,
67
{
68
    type Response = ServiceResponse<BoxBody>;
69
    type Error = Error;
70
    type Transform = AuthMiddleware<S>;
71
    type InitError = ();
72
    type Future = Ready<Result<Self::Transform, Self::InitError>>;
73

            
74
12628
    fn new_transform(&self, service: S) -> Self::Future {
75
12628
        future::ok(AuthMiddleware {
76
12628
            client: reqwest::Client::new(),
77
12628
            auth_uri: self.auth_uri.clone(),
78
12628
            service: Rc::new(RefCell::new(service)),
79
12628
        })
80
12628
    }
81
}
82

            
83
impl<S> Service<ServiceRequest> for AuthMiddleware<S>
84
where
85
    S: Service<ServiceRequest, Response = ServiceResponse<BoxBody>, Error = Error> + 'static,
86
    S::Future: 'static,
87
{
88
    type Response = ServiceResponse<BoxBody>;
89
    type Error = Error;
90
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
91

            
92
    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
93
        self.service.poll_ready(cx)
94
    }
95

            
96
1270
    fn call(&self, mut req: ServiceRequest) -> Self::Future {
97
1270
        let svc = self.service.clone();
98
1270
        let client = self.client.clone();
99
1270
        let auth_uri = self.auth_uri.clone();
100
1270

            
101
1270
        Box::pin(async move {
102
1270
            let (http_req, _) = req.parts_mut();
103
1270
            let token = match sylvia_http::parse_header_auth(&http_req) {
104
                Err(e) => {
105
                    return Ok(ServiceResponse::from_err(e, http_req.clone()));
106
                }
107
1270
                Ok(token) => match token {
108
                    None => {
109
2
                        let e = ErrResp::ErrParam(Some("missing token".to_string()));
110
2
                        return Ok(ServiceResponse::from_err(e, http_req.clone()));
111
                    }
112
1268
                    Some(token) => token,
113
                },
114
            };
115

            
116
1268
            let token_req = match client
117
1268
                .request(reqwest::Method::GET, auth_uri.as_str())
118
1268
                .header(reqwest::header::AUTHORIZATION, token.as_str())
119
1268
                .build()
120
            {
121
                Err(e) => {
122
                    let e = ErrResp::ErrRsc(Some(format!("request auth error: {}", e)));
123
                    return Ok(ServiceResponse::from_err(e, http_req.clone()));
124
                }
125
1268
                Ok(req) => req,
126
            };
127
6338
            let resp = match client.execute(token_req).await {
128
2
                Err(e) => {
129
2
                    let e = ErrResp::ErrIntMsg(Some(format!("auth error: {}", e)));
130
2
                    return Ok(ServiceResponse::from_err(e, http_req.clone()));
131
                }
132
1266
                Ok(resp) => match resp.status() {
133
                    reqwest::StatusCode::UNAUTHORIZED => {
134
22
                        let e = ErrResp::ErrAuth(None);
135
22
                        return Ok(ServiceResponse::from_err(e, http_req.clone()));
136
                    }
137
1244
                    reqwest::StatusCode::OK => resp,
138
                    _ => {
139
                        let e = ErrResp::ErrIntMsg(Some(format!(
140
                            "auth error with status code: {}",
141
                            resp.status()
142
                        )));
143
                        return Ok(ServiceResponse::from_err(e, http_req.clone()));
144
                    }
145
                },
146
            };
147
1244
            let token_info = match resp.json::<GetTokenInfo>().await {
148
                Err(e) => {
149
                    let e = ErrResp::ErrIntMsg(Some(format!("read auth body error: {}", e)));
150
                    return Ok(ServiceResponse::from_err(e, http_req.clone()));
151
                }
152
1244
                Ok(info) => info,
153
1244
            };
154
1244

            
155
1244
            req.extensions_mut().insert(FullTokenInfo {
156
1244
                token,
157
1244
                info: token_info.data,
158
1244
            });
159

            
160
2230
            let res = svc.call(req).await?;
161
1244
            Ok(res)
162
1270
        })
163
1270
    }
164
}