//! Provides the operation log middleware, which records requests by sending them to the data
//! channel.
2

            
3
use std::{
4
    collections::HashMap,
5
    net::SocketAddr,
6
    task::{Context, Poll},
7
};
8

            
9
use axum::{
10
    body::{self, Body},
11
    extract::{ConnectInfo, Request},
12
    http::Method,
13
    response::{IntoResponse, Response},
14
};
15
use chrono::Utc;
16
use futures::future::BoxFuture;
17
use reqwest;
18
use serde::{self, Deserialize, Serialize};
19
use serde_json::{Map, Value};
20
use tower::{Layer, Service};
21

            
22
use general_mq::{queue::GmqQueue, Queue};
23
use sylvia_iot_corelib::{err::ErrResp, http as sylvia_http, strings};
24

            
25
/// User/client information associated with an access token.
///
/// NOTE(review): nothing in the visible part of this file constructs or reads this type. Apart
/// from `token`, its fields carry the same names as the wire fields of `GetTokenInfoDataInner`;
/// presumably it is shared with other middlewares or handlers -- confirm against callers.
#[derive(Clone)]
pub struct GetTokenInfoData {
    pub token: String,
    pub user_id: String,
    pub account: String,
    pub roles: HashMap<String, bool>,
    pub name: String,
    pub client_id: String,
    pub scopes: Vec<String>,
}
35

            
36
#[derive(Clone)]
37
pub struct LogService {
38
    auth_uri: String,
39
    queue: Option<Queue>,
40
}
41

            
42
#[derive(Clone)]
43
pub struct LogMiddleware<S> {
44
    client: reqwest::Client,
45
    auth_uri: String,
46
    queue: Option<Queue>,
47
    service: S,
48
}
49

            
50
/// The user/client information of the token.
51
12
#[derive(Deserialize)]
52
struct GetTokenInfo {
53
    data: GetTokenInfoDataInner,
54
}
55

            
56
42
#[derive(Deserialize)]
57
struct GetTokenInfoDataInner {
58
    #[serde(rename = "userId")]
59
    user_id: String,
60
    #[serde(rename = "account")]
61
    _account: String,
62
    #[serde(rename = "roles")]
63
    _roles: HashMap<String, bool>,
64
    #[serde(rename = "name")]
65
    _name: String,
66
    #[serde(rename = "clientId")]
67
    client_id: String,
68
    #[serde(rename = "scopes")]
69
    _scopes: Vec<String>,
70
}
71

            
72
#[derive(Serialize)]
73
struct SendDataMsg {
74
    kind: String,
75
    data: SendDataKind,
76
}
77

            
78
#[derive(Serialize)]
79
#[serde(untagged)]
80
enum SendDataKind {
81
    Operation {
82
        #[serde(rename = "dataId")]
83
        data_id: String,
84
        #[serde(rename = "reqTime")]
85
        req_time: String,
86
        #[serde(rename = "resTime")]
87
        res_time: String,
88
        #[serde(rename = "latencyMs")]
89
        latency_ms: i64,
90
        status: isize,
91
        #[serde(rename = "sourceIp")]
92
        source_ip: String,
93
        method: String,
94
        path: String,
95
        #[serde(skip_serializing_if = "Option::is_none")]
96
        body: Option<Map<String, Value>>,
97
        #[serde(rename = "userId")]
98
        user_id: String,
99
        #[serde(rename = "clientId")]
100
        client_id: String,
101
        #[serde(rename = "errCode", skip_serializing_if = "Option::is_none")]
102
        err_code: Option<String>,
103
        #[serde(rename = "errMessage", skip_serializing_if = "Option::is_none")]
104
        err_message: Option<String>,
105
    },
106
}
107

            
108
/// Namespace for the `kind` values of `SendDataMsg`.
struct DataMsgKind;
109

            
110
/// The length argument passed to `strings::random_id()` when generating the `dataId` of a log.
const DATA_ID_RAND_LEN: usize = 12;
111

            
112
impl DataMsgKind {
113
    const OPERATION: &'static str = "operation";
114
}
115

            
116
impl LogService {
117
265
    pub fn new(auth_uri: String, queue: Option<Queue>) -> Self {
118
265
        LogService { auth_uri, queue }
119
265
    }
120
}
121

            
122
impl<S> Layer<S> for LogService {
123
    type Service = LogMiddleware<S>;
124

            
125
29661
    fn layer(&self, inner: S) -> Self::Service {
126
29661
        LogMiddleware {
127
29661
            client: reqwest::Client::new(),
128
29661
            auth_uri: self.auth_uri.clone(),
129
29661
            queue: self.queue.clone(),
130
29661
            service: inner,
131
29661
        }
132
29661
    }
133
}
134

            
135
impl<S> Service<Request> for LogMiddleware<S>
136
where
137
    S: Service<Request, Response = Response> + Clone + Send + 'static,
138
    S::Future: Send + 'static,
139
{
140
    type Response = S::Response;
141
    type Error = S::Error;
142
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
143

            
144
556
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
145
556
        self.service.poll_ready(cx)
146
556
    }
147

            
148
556
    fn call(&mut self, req: Request) -> Self::Future {
149
556
        let mut svc = self.service.clone();
150
556
        let client = self.client.clone();
151
556
        let auth_uri = self.auth_uri.clone();
152
556
        let method = req.method().clone();
153
556
        let queue = match method {
154
259
            Method::DELETE | Method::PATCH | Method::POST | Method::PUT => self.queue.clone(),
155
297
            _ => None,
156
        };
157

            
158
556
        Box::pin(async move {
159
            // Only log for DELETE/PATCH/POST/PUT methods.
160
556
            let q = match queue.as_ref() {
161
                None => {
162
1683
                    let res = svc.call(req).await?;
163
541
                    return Ok(res);
164
                }
165
15
                Some(q) => q,
166
15
            };
167
15

            
168
15
            let req_time = Utc::now();
169

            
170
            // Collect body (and generate a new stream) and information for logging the operation.
171
15
            let source_ip = match req.extensions().get::<ConnectInfo<SocketAddr>>() {
172
15
                None => "".to_string(),
173
                Some(info) => info.0.to_string(),
174
            };
175
15
            let method = req.method().to_string();
176
15
            let path = req.uri().to_string();
177
15
            let auth_token = match sylvia_http::parse_header_auth(&req) {
178
                Err(_) => None,
179
15
                Ok(token) => match token {
180
3
                    None => None,
181
12
                    Some(token) => Some(token),
182
                },
183
            };
184
15
            let (parts, body) = req.into_parts();
185
15
            let body_bytes = match body::to_bytes(body, usize::MAX).await {
186
                Err(e) => {
187
                    let e = format!("convert body error: {}", e);
188
                    return Ok(ErrResp::ErrParam(Some(e)).into_response());
189
                }
190
15
                Ok(body_bytes) => body_bytes,
191
            };
192
15
            let log_body = match serde_json::from_slice::<Map<String, Value>>(&body_bytes.to_vec())
193
            {
194
9
                Err(_) => None,
195
6
                Ok(mut body) => {
196
                    // Remove secret contents.
197
6
                    if let Some(data) = body.get_mut("data") {
198
6
                        if let Some(data) = data.as_object_mut() {
199
6
                            if data.contains_key("password") {
200
3
                                data.insert("password".to_string(), Value::String("".to_string()));
201
3
                            }
202
                        }
203
                    }
204
6
                    Some(body)
205
                }
206
            };
207
15
            let req = Request::from_parts(parts, Body::from(body_bytes));
208

            
209
            // Do the request.
210
15
            let res = svc.call(req).await?;
211
15
            let (err_code, err_message) = match res.status().is_success() {
212
                false => {
213
                    // TODO: extract (code, message) pair of response body.
214
9
                    (None, None)
215
                }
216
6
                true => (None, None),
217
            };
218

            
219
            // Send log.
220
15
            let auth_token = match auth_token.as_ref() {
221
3
                None => return Ok(res),
222
12
                Some(auth_token) => auth_token,
223
            };
224
48
            let token_info = match get_token(client, auth_token.as_str(), auth_uri.as_str()).await {
225
6
                Err(_) => return Ok(res),
226
6
                Ok(token_info) => token_info,
227
6
            };
228
6
            let res_time = Utc::now();
229
6
            let msg = SendDataMsg {
230
6
                kind: DataMsgKind::OPERATION.to_string(),
231
6
                data: SendDataKind::Operation {
232
6
                    data_id: strings::random_id(&req_time, DATA_ID_RAND_LEN),
233
6
                    req_time: strings::time_str(&req_time),
234
6
                    res_time: strings::time_str(&res_time),
235
6
                    latency_ms: res_time.timestamp_millis() - req_time.timestamp_millis(),
236
6
                    status: res.status().as_u16() as isize,
237
6
                    source_ip,
238
6
                    method,
239
6
                    path,
240
6
                    body: log_body,
241
6
                    user_id: token_info.data.user_id,
242
6
                    client_id: token_info.data.client_id,
243
6
                    err_code,
244
6
                    err_message,
245
6
                },
246
6
            };
247
6
            if let Ok(payload) = serde_json::to_vec(&msg) {
248
6
                let _ = q.send_msg(payload).await;
249
            }
250
6
            Ok(res)
251
556
        })
252
556
    }
253
}
254

            
255
12
async fn get_token(
256
12
    client: reqwest::Client,
257
12
    auth_token: &str,
258
12
    auth_uri: &str,
259
12
) -> Result<GetTokenInfo, String> {
260
12
    let token_req = match client
261
12
        .request(reqwest::Method::GET, auth_uri)
262
12
        .header(reqwest::header::AUTHORIZATION, auth_token)
263
12
        .build()
264
    {
265
        Err(e) => return Err(format!("request auth error: {}", e)),
266
12
        Ok(req) => req,
267
    };
268
48
    let resp = match client.execute(token_req).await {
269
        Err(e) => return Err(format!("auth error: {}", e)),
270
12
        Ok(resp) => match resp.status() {
271
6
            reqwest::StatusCode::UNAUTHORIZED => return Err("unauthorized".to_string()),
272
6
            reqwest::StatusCode::OK => resp,
273
            _ => return Err(format!("auth error with status code: {}", resp.status())),
274
        },
275
    };
276
6
    match resp.json::<GetTokenInfo>().await {
277
        Err(e) => Err(format!("read auth body error: {}", e)),
278
6
        Ok(info) => Ok(info),
279
    }
280
12
}