1
use std::{
2
    borrow::Cow,
3
    collections::{HashMap, HashSet},
4
    str,
5
    sync::Arc,
6
    task::{Context, Poll},
7
};
8

            
9
use axum::{
10
    extract::{FromRequest, Request},
11
    http::{header, Method, StatusCode},
12
    response::{IntoResponse, Response},
13
};
14
use futures::future::BoxFuture;
15
use oxide_auth::code_grant::resource::{Error as ResourceError, Request as OxideResourceRequest};
16
use oxide_auth_async::code_grant;
17
use tower::{Layer, Service};
18

            
19
use sylvia_iot_corelib::{err::ErrResp, http::parse_header_auth};
20

            
21
use super::endpoint::Endpoint;
22
use crate::models::{
23
    client::QueryCond as ClientQueryCond, user::QueryCond as UserQueryCond, Model,
24
};
25

            
26
/// Caller-supplied access requirements for one HTTP method, as `(roles, scopes)`.
///
/// An empty list means "no requirement" for that dimension; a non-empty list is
/// satisfied when at least one of its entries matches.
pub type RoleScopeType = (Vec<&'static str>, Vec<String>);
/// Set-based form of `RoleScopeType`, built once per layer so that per-request
/// checks are set lookups.
type RoleScopeInner = (HashSet<&'static str>, HashSet<String>);
28

            
29
#[derive(Clone)]
30
pub struct AuthService {
31
    model: Arc<dyn Model>,
32
    role_scopes: HashMap<Method, RoleScopeType>,
33
}
34

            
35
#[derive(Clone)]
36
pub struct AuthMiddleware<S> {
37
    endpoint: Endpoint,
38
    model: Arc<dyn Model>,
39
    role_scopes: HashMap<Method, RoleScopeInner>,
40
    service: S,
41
}
42

            
43
/// Minimal resource request handed to `oxide_auth`.
///
/// It carries only the token value returned by `parse_header_auth`, if any.
struct ResourceRequest {
    /// Token string, or `None` when no token was supplied.
    authorization: Option<String>,
}
46

            
47
impl AuthService {
48
10716
    pub fn new(model: &Arc<dyn Model>, role_scopes: HashMap<Method, RoleScopeType>) -> Self {
49
10716
        AuthService {
50
10716
            model: model.clone(),
51
10716
            role_scopes,
52
10716
        }
53
10716
    }
54
}
55

            
56
impl<S> Layer<S> for AuthService {
57
    type Service = AuthMiddleware<S>;
58

            
59
27284
    fn layer(&self, inner: S) -> Self::Service {
60
27284
        let mut role_scopes: HashMap<Method, RoleScopeInner> = HashMap::new();
61
50672
        for (k, (r, s)) in self.role_scopes.iter() {
62
50672
            role_scopes.insert(
63
50672
                k.clone(),
64
50672
                (
65
68180
                    r.iter().map(|&r| r).collect(),
66
50672
                    s.iter().map(|s| s.clone()).collect(),
67
50672
                ),
68
50672
            );
69
50672
        }
70
27284
        AuthMiddleware {
71
27284
            endpoint: Endpoint::new(self.model.clone(), Some("")),
72
27284
            model: self.model.clone(),
73
27284
            role_scopes,
74
27284
            service: inner,
75
27284
        }
76
27284
    }
77
}
78

            
79
impl<S> Service<Request> for AuthMiddleware<S>
80
where
81
    S: Service<Request, Response = Response> + Clone + Send + 'static,
82
    S::Future: Send + 'static,
83
{
84
    type Response = S::Response;
85
    type Error = S::Error;
86
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
87

            
88
500
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
89
500
        self.service.poll_ready(cx)
90
500
    }
91

            
92
500
    fn call(&mut self, mut req: Request) -> Self::Future {
93
500
        let mut svc = self.service.clone();
94
500
        let mut endpoint = self.endpoint.clone();
95
500
        let model = self.model.clone();
96
500
        let role_scopes = self.role_scopes.clone();
97
500

            
98
500
        Box::pin(async move {
99
500
            let auth_req = match ResourceRequest::new(&req) {
100
                Err(e) => return Ok(e.into_response()),
101
500
                Ok(req) => match req.token().is_none() {
102
498
                    false => req,
103
                    true => {
104
2
                        let e = ErrResp::ErrParam(Some("missing token".to_string()));
105
2
                        return Ok(e.into_response());
106
                    }
107
                },
108
            };
109
498
            let grant = match code_grant::resource::protect(&mut endpoint, &auth_req).await {
110
42
                Err(e) => match e {
111
                    ResourceError::PrimitiveError => {
112
                        return Ok(ErrResp::ErrDb(None).into_response());
113
                    }
114
                    _ => {
115
42
                        return Ok((
116
42
                            StatusCode::UNAUTHORIZED,
117
42
                            [(header::WWW_AUTHENTICATE, e.www_authenticate())],
118
42
                        )
119
42
                            .into_response());
120
                    }
121
                },
122
456
                Ok(grant) => grant,
123
456
            };
124
456

            
125
456
            let cond = UserQueryCond {
126
456
                user_id: Some(grant.owner_id.as_str()),
127
456
                account: None,
128
456
            };
129
456
            let user = match model.user().get(&cond).await {
130
                Err(e) => {
131
                    return Ok(ErrResp::ErrDb(Some(e.to_string())).into_response());
132
                }
133
456
                Ok(user) => match user {
134
                    None => {
135
                        let e = ErrResp::ErrPerm(Some("user not exist".to_string()));
136
                        return Ok(e.into_response());
137
                    }
138
456
                    Some(user) => {
139
456
                        if let Some((api_roles, api_scopes)) = role_scopes.get(req.method()) {
140
456
                            if api_roles.len() > 0 {
141
390
                                let roles: HashSet<&str> = user
142
390
                                    .roles
143
390
                                    .iter()
144
390
                                    .filter(|(_, &v)| v)
145
390
                                    .map(|(k, _)| k.as_str())
146
390
                                    .collect();
147
390
                                if api_roles.is_disjoint(&roles) {
148
72
                                    let e = ErrResp::ErrPerm(Some("invalid role".to_string()));
149
72
                                    return Ok(e.into_response());
150
318
                                }
151
66
                            }
152
384
                            if api_scopes.len() > 0 {
153
4
                                let api_scopes: HashSet<&str> =
154
4
                                    api_scopes.iter().map(|s| s.as_str()).collect();
155
4
                                let scopes: HashSet<&str> = grant.scope.iter().map(|s| s).collect();
156
4
                                if api_scopes.is_disjoint(&scopes) {
157
2
                                    return Ok(ErrResp::ErrPerm(Some("invalid scope".to_string()))
158
2
                                        .into_response());
159
2
                                }
160
380
                            }
161
                        }
162
382
                        user
163
382
                    }
164
382
                },
165
382
            };
166
382
            req.extensions_mut().insert(user);
167
382

            
168
382
            let cond = ClientQueryCond {
169
382
                client_id: Some(grant.client_id.as_str()),
170
382
                ..Default::default()
171
382
            };
172
382
            let client = match model.client().get(&cond).await {
173
                Err(e) => {
174
                    return Ok(ErrResp::ErrDb(Some(e.to_string())).into_response());
175
                }
176
382
                Ok(client) => match client {
177
                    None => {
178
                        let e = ErrResp::ErrPerm(Some("client not exist".to_string()));
179
                        return Ok(e.into_response());
180
                    }
181
382
                    Some(client) => client,
182
382
                },
183
382
            };
184
382
            req.extensions_mut().insert(client);
185

            
186
382
            let res = svc.call(req).await?;
187
382
            Ok(res)
188
500
        })
189
500
    }
190
}
191

            
192
impl ResourceRequest {
193
500
    fn new(req: &Request) -> Result<Self, ErrResp> {
194
500
        match parse_header_auth(req) {
195
            Err(e) => Err(e),
196
500
            Ok(auth) => Ok(ResourceRequest {
197
500
                authorization: auth,
198
500
            }),
199
        }
200
500
    }
201
}
202

            
203
impl<S> FromRequest<S> for ResourceRequest
204
where
205
    S: Send + Sync,
206
{
207
    type Rejection = Response;
208

            
209
    async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
210
        match parse_header_auth(&req) {
211
            Err(e) => Err(e.into_response()),
212
            Ok(auth) => Ok(ResourceRequest {
213
                authorization: auth,
214
            }),
215
        }
216
    }
217
}
218

            
219
impl OxideResourceRequest for ResourceRequest {
220
498
    fn valid(&self) -> bool {
221
498
        true
222
498
    }
223

            
224
998
    fn token(&self) -> Option<Cow<str>> {
225
998
        match self.authorization.as_deref() {
226
2
            None => None,
227
996
            Some(auth) => Some(Cow::from(auth)),
228
        }
229
998
    }
230
}