1
use std::{
2
    borrow::Cow,
3
    collections::{HashMap, HashSet},
4
    str,
5
    sync::Arc,
6
    task::{Context, Poll},
7
};
8

            
9
use axum::{
10
    extract::{FromRequest, Request},
11
    http::{Method, StatusCode, header},
12
    response::{IntoResponse, Response},
13
};
14
use futures::future::BoxFuture;
15
use oxide_auth::code_grant::resource::{Error as ResourceError, Request as OxideResourceRequest};
16
use oxide_auth_async::code_grant;
17
use tower::{Layer, Service};
18

            
19
use sylvia_iot_corelib::{err::ErrResp, http::parse_header_auth};
20

            
21
use super::endpoint::Endpoint;
22
use crate::models::{
23
    Model, client::QueryCond as ClientQueryCond, user::QueryCond as UserQueryCond,
24
};
25

            
26
/// Role and scope requirements for one HTTP method, as supplied by callers:
/// `(required roles, required scopes)`.
pub type RoleScopeType = (Vec<&'static str>, Vec<String>);

/// Set-based form of [`RoleScopeType`] used internally for membership checks.
type RoleScopeInner = (HashSet<&'static str>, HashSet<String>);
28

            
29
#[derive(Clone)]
30
pub struct AuthService {
31
    model: Arc<dyn Model>,
32
    role_scopes: HashMap<Method, RoleScopeType>,
33
}
34

            
35
#[derive(Clone)]
36
pub struct AuthMiddleware<S> {
37
    endpoint: Endpoint,
38
    model: Arc<dyn Model>,
39
    role_scopes: HashMap<Method, RoleScopeInner>,
40
    service: S,
41
}
42

            
43
/// Credentials extracted from an incoming HTTP request, handed to `oxide_auth` resource
/// protection as its request type.
struct ResourceRequest {
    /// Value produced by `parse_header_auth`; `None` presumably means no authorization header
    /// was sent (the middleware rejects that case as "missing token").
    authorization: Option<String>,
}
46

            
47
impl AuthService {
48
21432
    pub fn new(model: &Arc<dyn Model>, role_scopes: HashMap<Method, RoleScopeType>) -> Self {
49
21432
        AuthService {
50
21432
            model: model.clone(),
51
21432
            role_scopes,
52
21432
        }
53
21432
    }
54
}
55

            
56
impl<S> Layer<S> for AuthService {
57
    type Service = AuthMiddleware<S>;
58

            
59
54568
    fn layer(&self, inner: S) -> Self::Service {
60
54568
        let mut role_scopes: HashMap<Method, RoleScopeInner> = HashMap::new();
61
101344
        for (k, (r, s)) in self.role_scopes.iter() {
62
101344
            role_scopes.insert(
63
101344
                k.clone(),
64
101344
                (
65
136360
                    r.iter().map(|&r| r).collect(),
66
101344
                    s.iter().map(|s| s.clone()).collect(),
67
101344
                ),
68
101344
            );
69
101344
        }
70
54568
        AuthMiddleware {
71
54568
            endpoint: Endpoint::new(self.model.clone(), Some("")),
72
54568
            model: self.model.clone(),
73
54568
            role_scopes,
74
54568
            service: inner,
75
54568
        }
76
54568
    }
77
}
78

            
79
impl<S> Service<Request> for AuthMiddleware<S>
80
where
81
    S: Service<Request, Response = Response> + Clone + Send + 'static,
82
    S::Future: Send + 'static,
83
{
84
    type Response = S::Response;
85
    type Error = S::Error;
86
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
87

            
88
1000
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
89
1000
        self.service.poll_ready(cx)
90
1000
    }
91

            
92
1000
    fn call(&mut self, mut req: Request) -> Self::Future {
93
1000
        let mut svc = self.service.clone();
94
1000
        let mut endpoint = self.endpoint.clone();
95
1000
        let model = self.model.clone();
96
1000
        let role_scopes = self.role_scopes.clone();
97
1000

            
98
1000
        Box::pin(async move {
99
1000
            let auth_req = match ResourceRequest::new(&req) {
100
                Err(e) => return Ok(e.into_response()),
101
1000
                Ok(req) => match req.token().is_none() {
102
996
                    false => req,
103
                    true => {
104
4
                        let e = ErrResp::ErrParam(Some("missing token".to_string()));
105
4
                        return Ok(e.into_response());
106
                    }
107
                },
108
            };
109
996
            let grant = match code_grant::resource::protect(&mut endpoint, &auth_req).await {
110
84
                Err(e) => match e {
111
                    ResourceError::PrimitiveError => {
112
                        return Ok(ErrResp::ErrDb(None).into_response());
113
                    }
114
                    _ => {
115
84
                        return Ok((
116
84
                            StatusCode::UNAUTHORIZED,
117
84
                            [(header::WWW_AUTHENTICATE, e.www_authenticate())],
118
84
                        )
119
84
                            .into_response());
120
                    }
121
                },
122
912
                Ok(grant) => grant,
123
912
            };
124
912

            
125
912
            let cond = UserQueryCond {
126
912
                user_id: Some(grant.owner_id.as_str()),
127
912
                account: None,
128
912
            };
129
912
            let user = match model.user().get(&cond).await {
130
                Err(e) => {
131
                    return Ok(ErrResp::ErrDb(Some(e.to_string())).into_response());
132
                }
133
912
                Ok(user) => match user {
134
                    None => {
135
                        let e = ErrResp::ErrPerm(Some("user not exist".to_string()));
136
                        return Ok(e.into_response());
137
                    }
138
912
                    Some(user) => {
139
912
                        if let Some((api_roles, api_scopes)) = role_scopes.get(req.method()) {
140
912
                            if api_roles.len() > 0 {
141
780
                                let roles: HashSet<&str> = user
142
780
                                    .roles
143
780
                                    .iter()
144
780
                                    .filter(|(_, v)| **v)
145
780
                                    .map(|(k, _)| k.as_str())
146
780
                                    .collect();
147
780
                                if api_roles.is_disjoint(&roles) {
148
144
                                    let e = ErrResp::ErrPerm(Some("invalid role".to_string()));
149
144
                                    return Ok(e.into_response());
150
636
                                }
151
132
                            }
152
768
                            if api_scopes.len() > 0 {
153
8
                                let api_scopes: HashSet<&str> =
154
8
                                    api_scopes.iter().map(|s| s.as_str()).collect();
155
8
                                let scopes: HashSet<&str> = grant.scope.iter().map(|s| s).collect();
156
8
                                if api_scopes.is_disjoint(&scopes) {
157
4
                                    return Ok(ErrResp::ErrPerm(Some("invalid scope".to_string()))
158
4
                                        .into_response());
159
4
                                }
160
760
                            }
161
                        }
162
764
                        user
163
764
                    }
164
764
                },
165
764
            };
166
764
            req.extensions_mut().insert(user);
167
764

            
168
764
            let cond = ClientQueryCond {
169
764
                client_id: Some(grant.client_id.as_str()),
170
764
                ..Default::default()
171
764
            };
172
764
            let client = match model.client().get(&cond).await {
173
                Err(e) => {
174
                    return Ok(ErrResp::ErrDb(Some(e.to_string())).into_response());
175
                }
176
764
                Ok(client) => match client {
177
                    None => {
178
                        let e = ErrResp::ErrPerm(Some("client not exist".to_string()));
179
                        return Ok(e.into_response());
180
                    }
181
764
                    Some(client) => client,
182
764
                },
183
764
            };
184
764
            req.extensions_mut().insert(client);
185

            
186
764
            let res = svc.call(req).await?;
187
764
            Ok(res)
188
1000
        })
189
1000
    }
190
}
191

            
192
impl ResourceRequest {
193
1000
    fn new(req: &Request) -> Result<Self, ErrResp> {
194
1000
        match parse_header_auth(req) {
195
            Err(e) => Err(e),
196
1000
            Ok(auth) => Ok(ResourceRequest {
197
1000
                authorization: auth,
198
1000
            }),
199
        }
200
1000
    }
201
}
202

            
203
impl<S> FromRequest<S> for ResourceRequest
204
where
205
    S: Send + Sync,
206
{
207
    type Rejection = Response;
208

            
209
    async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
210
        match parse_header_auth(&req) {
211
            Err(e) => Err(e.into_response()),
212
            Ok(auth) => Ok(ResourceRequest {
213
                authorization: auth,
214
            }),
215
        }
216
    }
217
}
218

            
219
impl OxideResourceRequest for ResourceRequest {
220
996
    fn valid(&self) -> bool {
221
996
        true
222
996
    }
223

            
224
1996
    fn token(&self) -> Option<Cow<str>> {
225
1996
        match self.authorization.as_deref() {
226
4
            None => None,
227
1992
            Some(auth) => Some(Cow::from(auth)),
228
        }
229
1996
    }
230
}