1
use std::{
2
    borrow::Cow,
3
    collections::{HashMap, HashSet},
4
    str,
5
    sync::Arc,
6
    task::{Context, Poll},
7
};
8

            
9
use async_trait::async_trait;
10
use axum::{
11
    extract::{FromRequest, Request},
12
    http::{header, Method, StatusCode},
13
    response::{IntoResponse, Response},
14
};
15
use futures::future::BoxFuture;
16
use oxide_auth::code_grant::resource::{Error as ResourceError, Request as OxideResourceRequest};
17
use oxide_auth_async::code_grant;
18
use tower::{Layer, Service};
19

            
20
use sylvia_iot_corelib::{err::ErrResp, http::parse_header_auth};
21

            
22
use super::endpoint::Endpoint;
23
use crate::models::{
24
    client::QueryCond as ClientQueryCond, user::QueryCond as UserQueryCond, Model,
25
};
26

            
27
/// Access rule for one HTTP method as supplied by callers: the accepted role names first,
/// then the accepted scope names.
pub type RoleScopeType = (Vec<&'static str>, Vec<String>);
/// The same (roles, scopes) pair held as sets, which is the form the middleware checks against.
type RoleScopeInner = (HashSet<&'static str>, HashSet<String>);
29

            
30
#[derive(Clone)]
31
pub struct AuthService {
32
    model: Arc<dyn Model>,
33
    role_scopes: HashMap<Method, RoleScopeType>,
34
}
35

            
36
#[derive(Clone)]
37
pub struct AuthMiddleware<S> {
38
    endpoint: Endpoint,
39
    model: Arc<dyn Model>,
40
    role_scopes: HashMap<Method, RoleScopeInner>,
41
    service: S,
42
}
43

            
44
/// Minimal resource request handed to the OAuth2 resource flow.
struct ResourceRequest {
    /// Value returned by `parse_header_auth` for the request; `None` when it found nothing.
    authorization: Option<String>,
}
47

            
48
impl AuthService {
49
10716
    pub fn new(model: &Arc<dyn Model>, role_scopes: HashMap<Method, RoleScopeType>) -> Self {
50
10716
        AuthService {
51
10716
            model: model.clone(),
52
10716
            role_scopes,
53
10716
        }
54
10716
    }
55
}
56

            
57
impl<S> Layer<S> for AuthService {
58
    type Service = AuthMiddleware<S>;
59

            
60
27284
    fn layer(&self, inner: S) -> Self::Service {
61
27284
        let mut role_scopes: HashMap<Method, RoleScopeInner> = HashMap::new();
62
50672
        for (k, (r, s)) in self.role_scopes.iter() {
63
50672
            role_scopes.insert(
64
50672
                k.clone(),
65
50672
                (
66
68180
                    r.iter().map(|&r| r).collect(),
67
50672
                    s.iter().map(|s| s.clone()).collect(),
68
50672
                ),
69
50672
            );
70
50672
        }
71
27284
        AuthMiddleware {
72
27284
            endpoint: Endpoint::new(self.model.clone(), Some("")),
73
27284
            model: self.model.clone(),
74
27284
            role_scopes,
75
27284
            service: inner,
76
27284
        }
77
27284
    }
78
}
79

            
80
impl<S> Service<Request> for AuthMiddleware<S>
81
where
82
    S: Service<Request, Response = Response> + Clone + Send + 'static,
83
    S::Future: Send + 'static,
84
{
85
    type Response = S::Response;
86
    type Error = S::Error;
87
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
88

            
89
500
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
90
500
        self.service.poll_ready(cx)
91
500
    }
92

            
93
500
    fn call(&mut self, mut req: Request) -> Self::Future {
94
500
        let mut svc = self.service.clone();
95
500
        let mut endpoint = self.endpoint.clone();
96
500
        let model = self.model.clone();
97
500
        let role_scopes = self.role_scopes.clone();
98
500

            
99
500
        Box::pin(async move {
100
500
            let auth_req = match ResourceRequest::new(&req) {
101
                Err(e) => return Ok(e.into_response()),
102
500
                Ok(req) => match req.token().is_none() {
103
498
                    false => req,
104
                    true => {
105
2
                        let e = ErrResp::ErrParam(Some("missing token".to_string()));
106
2
                        return Ok(e.into_response());
107
                    }
108
                },
109
            };
110
996
            let grant = match code_grant::resource::protect(&mut endpoint, &auth_req).await {
111
42
                Err(e) => match e {
112
                    ResourceError::PrimitiveError => {
113
                        return Ok(ErrResp::ErrDb(None).into_response());
114
                    }
115
                    _ => {
116
42
                        return Ok((
117
42
                            StatusCode::UNAUTHORIZED,
118
42
                            [(header::WWW_AUTHENTICATE, e.www_authenticate())],
119
42
                        )
120
42
                            .into_response());
121
                    }
122
                },
123
456
                Ok(grant) => grant,
124
456
            };
125
456

            
126
456
            let cond = UserQueryCond {
127
456
                user_id: Some(grant.owner_id.as_str()),
128
456
                account: None,
129
456
            };
130
911
            let user = match model.user().get(&cond).await {
131
                Err(e) => {
132
                    return Ok(ErrResp::ErrDb(Some(e.to_string())).into_response());
133
                }
134
456
                Ok(user) => match user {
135
                    None => {
136
                        let e = ErrResp::ErrPerm(Some("user not exist".to_string()));
137
                        return Ok(e.into_response());
138
                    }
139
456
                    Some(user) => {
140
456
                        if let Some((api_roles, api_scopes)) = role_scopes.get(req.method()) {
141
456
                            if api_roles.len() > 0 {
142
390
                                let roles: HashSet<&str> = user
143
390
                                    .roles
144
390
                                    .iter()
145
390
                                    .filter(|(_, &v)| v)
146
390
                                    .map(|(k, _)| k.as_str())
147
390
                                    .collect();
148
390
                                if api_roles.is_disjoint(&roles) {
149
72
                                    let e = ErrResp::ErrPerm(Some("invalid role".to_string()));
150
72
                                    return Ok(e.into_response());
151
318
                                }
152
66
                            }
153
384
                            if api_scopes.len() > 0 {
154
4
                                let api_scopes: HashSet<&str> =
155
4
                                    api_scopes.iter().map(|s| s.as_str()).collect();
156
4
                                let scopes: HashSet<&str> = grant.scope.iter().map(|s| s).collect();
157
4
                                if api_scopes.is_disjoint(&scopes) {
158
2
                                    return Ok(ErrResp::ErrPerm(Some("invalid scope".to_string()))
159
2
                                        .into_response());
160
2
                                }
161
380
                            }
162
                        }
163
382
                        user
164
382
                    }
165
382
                },
166
382
            };
167
382
            req.extensions_mut().insert(user);
168
382

            
169
382
            let cond = ClientQueryCond {
170
382
                client_id: Some(grant.client_id.as_str()),
171
382
                ..Default::default()
172
382
            };
173
764
            let client = match model.client().get(&cond).await {
174
                Err(e) => {
175
                    return Ok(ErrResp::ErrDb(Some(e.to_string())).into_response());
176
                }
177
382
                Ok(client) => match client {
178
                    None => {
179
                        let e = ErrResp::ErrPerm(Some("client not exist".to_string()));
180
                        return Ok(e.into_response());
181
                    }
182
382
                    Some(client) => client,
183
382
                },
184
382
            };
185
382
            req.extensions_mut().insert(client);
186

            
187
708
            let res = svc.call(req).await?;
188
382
            Ok(res)
189
500
        })
190
500
    }
191
}
192

            
193
impl ResourceRequest {
194
500
    fn new(req: &Request) -> Result<Self, ErrResp> {
195
500
        match parse_header_auth(req) {
196
            Err(e) => Err(e),
197
500
            Ok(auth) => Ok(ResourceRequest {
198
500
                authorization: auth,
199
500
            }),
200
        }
201
500
    }
202
}
203

            
204
#[async_trait]
205
impl<S> FromRequest<S> for ResourceRequest
206
where
207
    S: Send + Sync,
208
{
209
    type Rejection = Response;
210

            
211
    async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
212
        match parse_header_auth(&req) {
213
            Err(e) => Err(e.into_response()),
214
            Ok(auth) => Ok(ResourceRequest {
215
                authorization: auth,
216
            }),
217
        }
218
    }
219
}
220

            
221
impl OxideResourceRequest for ResourceRequest {
222
498
    fn valid(&self) -> bool {
223
498
        true
224
498
    }
225

            
226
998
    fn token(&self) -> Option<Cow<str>> {
227
998
        match self.authorization.as_deref() {
228
2
            None => None,
229
996
            Some(auth) => Some(Cow::from(auth)),
230
        }
231
998
    }
232
}