1
//! Provides the operation log middleware, which sends a record of each logged request to the data channel.
2

            
3
use std::{
4
    collections::HashMap,
5
    net::SocketAddr,
6
    task::{Context, Poll},
7
};
8

            
9
use axum::{
10
    body::{self, Body},
11
    extract::{ConnectInfo, Request},
12
    http::Method,
13
    response::{IntoResponse, Response},
14
};
15
use chrono::Utc;
16
use futures::future::BoxFuture;
17
use reqwest;
18
use serde::{self, Deserialize, Serialize};
19
use serde_json::{Map, Value};
20
use tower::{Layer, Service};
21

            
22
use general_mq::{Queue, queue::GmqQueue};
23
use sylvia_iot_corelib::{err::ErrResp, http as sylvia_http, strings};
24

            
25
/// User/client information associated with a token, together with the token itself.
#[derive(Clone)]
pub struct GetTokenInfoData {
    /// The token this information belongs to.
    pub token: String,
    /// Identifier of the token's user.
    pub user_id: String,
    /// Account of the token's user.
    pub account: String,
    /// Roles keyed by role name (presumably `true` means the role is held — confirm).
    pub roles: HashMap<String, bool>,
    /// Display name of the token's user.
    pub name: String,
    /// Identifier of the client the token was issued to.
    pub client_id: String,
    /// Scopes granted to the token.
    pub scopes: Vec<String>,
}
35

            
36
#[derive(Clone)]
37
pub struct LogService {
38
    client: reqwest::Client,
39
    auth_uri: String,
40
    queue: Option<Queue>,
41
}
42

            
43
#[derive(Clone)]
44
pub struct LogMiddleware<S> {
45
    client: reqwest::Client,
46
    auth_uri: String,
47
    queue: Option<Queue>,
48
    service: S,
49
}
50

            
51
/// The user/client information of the token.
52
#[derive(Deserialize)]
53
struct GetTokenInfo {
54
    data: GetTokenInfoDataInner,
55
}
56

            
57
#[derive(Deserialize)]
58
struct GetTokenInfoDataInner {
59
    #[serde(rename = "userId")]
60
    user_id: String,
61
    #[serde(rename = "account")]
62
    _account: String,
63
    #[serde(rename = "roles")]
64
    _roles: HashMap<String, bool>,
65
    #[serde(rename = "name")]
66
    _name: String,
67
    #[serde(rename = "clientId")]
68
    client_id: String,
69
    #[serde(rename = "scopes")]
70
    _scopes: Vec<String>,
71
}
72

            
73
#[derive(Serialize)]
74
struct SendDataMsg {
75
    kind: String,
76
    data: SendDataKind,
77
}
78

            
79
#[derive(Serialize)]
80
#[serde(untagged)]
81
enum SendDataKind {
82
    Operation {
83
        #[serde(rename = "dataId")]
84
        data_id: String,
85
        #[serde(rename = "reqTime")]
86
        req_time: String,
87
        #[serde(rename = "resTime")]
88
        res_time: String,
89
        #[serde(rename = "latencyMs")]
90
        latency_ms: i64,
91
        status: isize,
92
        #[serde(rename = "sourceIp")]
93
        source_ip: String,
94
        method: String,
95
        path: String,
96
        #[serde(skip_serializing_if = "Option::is_none")]
97
        body: Option<Map<String, Value>>,
98
        #[serde(rename = "userId")]
99
        user_id: String,
100
        #[serde(rename = "clientId")]
101
        client_id: String,
102
        #[serde(rename = "errCode", skip_serializing_if = "Option::is_none")]
103
        err_code: Option<String>,
104
        #[serde(rename = "errMessage", skip_serializing_if = "Option::is_none")]
105
        err_message: Option<String>,
106
    },
107
}
108

            
109
/// Namespace for the `kind` values of messages sent to the data channel.
struct DataMsgKind;

/// Length of the random part passed to `strings::random_id` when generating data IDs.
const DATA_ID_RAND_LEN: usize = 12;

impl DataMsgKind {
    /// `kind` of an operation log message.
    const OPERATION: &str = "operation";
}
116

            
117
impl LogService {
118
530
    pub fn new(client: reqwest::Client, auth_uri: String, queue: Option<Queue>) -> Self {
119
530
        LogService {
120
530
            client,
121
530
            auth_uri,
122
530
            queue,
123
530
        }
124
530
    }
125
}
126

            
127
impl<S> Layer<S> for LogService {
128
    type Service = LogMiddleware<S>;
129

            
130
59322
    fn layer(&self, inner: S) -> Self::Service {
131
59322
        LogMiddleware {
132
59322
            client: self.client.clone(),
133
59322
            auth_uri: self.auth_uri.clone(),
134
59322
            queue: self.queue.clone(),
135
59322
            service: inner,
136
59322
        }
137
59322
    }
138
}
139

            
140
impl<S> Service<Request> for LogMiddleware<S>
141
where
142
    S: Service<Request, Response = Response> + Clone + Send + 'static,
143
    S::Future: Send + 'static,
144
{
145
    type Response = S::Response;
146
    type Error = S::Error;
147
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
148

            
149
1100
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
150
1100
        self.service.poll_ready(cx)
151
1100
    }
152

            
153
1100
    fn call(&mut self, req: Request) -> Self::Future {
154
1100
        let mut svc = self.service.clone();
155
1100
        let client = self.client.clone();
156
1100
        let auth_uri = self.auth_uri.clone();
157
1100
        let method = req.method().clone();
158
1100
        let queue = match method {
159
518
            Method::DELETE | Method::PATCH | Method::POST | Method::PUT => self.queue.clone(),
160
582
            _ => None,
161
        };
162

            
163
1100
        Box::pin(async move {
164
            // Only log for DELETE/PATCH/POST/PUT methods.
165
1100
            let q = match queue.as_ref() {
166
                None => {
167
1070
                    let res = svc.call(req).await?;
168
1070
                    return Ok(res);
169
                }
170
30
                Some(q) => q,
171
            };
172

            
173
30
            let req_time = Utc::now();
174

            
175
            // Collect body (and generate a new stream) and information for logging the operation.
176
30
            let source_ip = match req.extensions().get::<ConnectInfo<SocketAddr>>() {
177
30
                None => "".to_string(),
178
                Some(info) => info.0.to_string(),
179
            };
180
30
            let method = req.method().to_string();
181
30
            let path = req.uri().to_string();
182
30
            let auth_token = match sylvia_http::parse_header_auth(&req) {
183
                Err(_) => None,
184
30
                Ok(token) => match token {
185
6
                    None => None,
186
24
                    Some(token) => Some(token),
187
                },
188
            };
189
30
            let (parts, body) = req.into_parts();
190
30
            let body_bytes = match body::to_bytes(body, usize::MAX).await {
191
                Err(e) => {
192
                    let e = format!("convert body error: {}", e);
193
                    return Ok(ErrResp::ErrParam(Some(e)).into_response());
194
                }
195
30
                Ok(body_bytes) => body_bytes,
196
            };
197
30
            let log_body = match serde_json::from_slice::<Map<String, Value>>(&body_bytes.to_vec())
198
            {
199
18
                Err(_) => None,
200
12
                Ok(mut body) => {
201
                    // Remove secret contents.
202
12
                    if let Some(data) = body.get_mut("data") {
203
12
                        if let Some(data) = data.as_object_mut() {
204
12
                            if data.contains_key("password") {
205
6
                                data.insert("password".to_string(), Value::String("".to_string()));
206
6
                            }
207
                        }
208
                    }
209
12
                    Some(body)
210
                }
211
            };
212
30
            let req = Request::from_parts(parts, Body::from(body_bytes));
213

            
214
            // Do the request.
215
30
            let res = svc.call(req).await?;
216
30
            let (err_code, err_message) = match res.status().is_success() {
217
                false => {
218
                    // TODO: extract (code, message) pair of response body.
219
18
                    (None, None)
220
                }
221
12
                true => (None, None),
222
            };
223

            
224
            // Send log.
225
30
            let auth_token = match auth_token.as_ref() {
226
6
                None => return Ok(res),
227
24
                Some(auth_token) => auth_token,
228
            };
229
24
            let token_info = match get_token(client, auth_token.as_str(), auth_uri.as_str()).await {
230
12
                Err(_) => return Ok(res),
231
12
                Ok(token_info) => token_info,
232
            };
233
12
            let res_time = Utc::now();
234
12
            let msg = SendDataMsg {
235
12
                kind: DataMsgKind::OPERATION.to_string(),
236
12
                data: SendDataKind::Operation {
237
12
                    data_id: strings::random_id(&req_time, DATA_ID_RAND_LEN),
238
12
                    req_time: strings::time_str(&req_time),
239
12
                    res_time: strings::time_str(&res_time),
240
12
                    latency_ms: res_time.timestamp_millis() - req_time.timestamp_millis(),
241
12
                    status: res.status().as_u16() as isize,
242
12
                    source_ip,
243
12
                    method,
244
12
                    path,
245
12
                    body: log_body,
246
12
                    user_id: token_info.data.user_id,
247
12
                    client_id: token_info.data.client_id,
248
12
                    err_code,
249
12
                    err_message,
250
12
                },
251
12
            };
252
12
            if let Ok(payload) = serde_json::to_vec(&msg) {
253
12
                let _ = q.send_msg(payload).await;
254
            }
255
12
            Ok(res)
256
1100
        })
257
1100
    }
258
}
259

            
260
24
async fn get_token(
261
24
    client: reqwest::Client,
262
24
    auth_token: &str,
263
24
    auth_uri: &str,
264
24
) -> Result<GetTokenInfo, String> {
265
24
    let token_req = match client
266
24
        .request(reqwest::Method::GET, auth_uri)
267
24
        .header(reqwest::header::AUTHORIZATION, auth_token)
268
24
        .build()
269
    {
270
        Err(e) => return Err(format!("request auth error: {}", e)),
271
24
        Ok(req) => req,
272
    };
273
24
    let resp = match client.execute(token_req).await {
274
        Err(e) => return Err(format!("auth error: {}", e)),
275
24
        Ok(resp) => match resp.status() {
276
12
            reqwest::StatusCode::UNAUTHORIZED => return Err("unauthorized".to_string()),
277
12
            reqwest::StatusCode::OK => resp,
278
            _ => return Err(format!("auth error with status code: {}", resp.status())),
279
        },
280
    };
281
12
    match resp.json::<GetTokenInfo>().await {
282
        Err(e) => Err(format!("read auth body error: {}", e)),
283
12
        Ok(info) => Ok(info),
284
    }
285
24
}