use super::AuthenticationProvider; use axum::{extract::Request, response::Response}; use futures_core::future::BoxFuture; use headers::{Authorization, HeaderMapExt, authorization::Basic}; use std::{ sync::Arc, task::{Context, Poll}, }; use tower::{Layer, Service}; use tower_sessions::Session; use tracing::{Instrument, info_span}; pub struct AuthenticationLayer { auth_provider: Arc, } impl Clone for AuthenticationLayer { fn clone(&self) -> Self { Self { auth_provider: self.auth_provider.clone(), } } } impl AuthenticationLayer { pub fn new(auth_provider: Arc) -> Self { Self { auth_provider } } } impl Layer for AuthenticationLayer { type Service = AuthenticationMiddleware; fn layer(&self, inner: S) -> Self::Service { Self::Service { inner, auth_provider: self.auth_provider.clone(), } } } pub struct AuthenticationMiddleware { inner: S, auth_provider: Arc, } impl Clone for AuthenticationMiddleware { fn clone(&self) -> Self { Self { inner: self.inner.clone(), auth_provider: self.auth_provider.clone(), } } } impl Service for AuthenticationMiddleware where S: Service + Send + 'static, S::Future: Send + 'static, { type Response = S::Response; type Error = S::Error; type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, mut request: Request) -> Self::Future { let auth_header: Option> = request.headers().typed_get(); let ap = self.auth_provider.clone(); let mut inner = self.inner.clone(); Box::pin(async move { if let Some(session) = request.extensions().get::() && let Ok(Some(user_id)) = session.get::("user").await && let Ok(Some(user)) = ap.get_principal(&user_id).await { request.extensions_mut().insert(user); } if let Some(auth) = auth_header { let user_id = auth.username(); let password = auth.password(); if let Ok(Some(user)) = ap .validate_app_token(user_id, password) .instrument(info_span!("validate_user_token")) .await { request.extensions_mut().insert(user); } } let response = inner.call(request).await?; Ok(response) }) } }