diff --git a/crates/caldav/src/calendar/methods/import.rs b/crates/caldav/src/calendar/methods/import.rs index a909cf0..8cadeb6 100644 --- a/crates/caldav/src/calendar/methods/import.rs +++ b/crates/caldav/src/calendar/methods/import.rs @@ -22,7 +22,7 @@ pub async fn route_import( Path((principal, cal_id)): Path<(String, String)>, user: Principal, State(resource_service): State>, - overwrite: Overwrite, + Overwrite(overwrite): Overwrite, body: String, ) -> Result { if !user.is_principal(&principal) { @@ -103,7 +103,7 @@ pub async fn route_import( let cal_store = resource_service.cal_store; cal_store - .import_calendar(new_cal, objects, overwrite.is_true()) + .import_calendar(new_cal, objects, overwrite) .await?; Ok(StatusCode::OK.into_response()) diff --git a/crates/dav/src/header/overwrite.rs b/crates/dav/src/header/overwrite.rs index 29ae8d8..06cbb9e 100644 --- a/crates/dav/src/header/overwrite.rs +++ b/crates/dav/src/header/overwrite.rs @@ -14,16 +14,12 @@ impl IntoResponse for InvalidOverwriteHeader { } } -#[derive(Debug, PartialEq, Default)] -pub enum Overwrite { - #[default] - T, - F, -} +#[derive(Debug, PartialEq)] +pub struct Overwrite(pub bool); -impl Overwrite { - pub fn is_true(&self) -> bool { - matches!(self, Self::T) +impl Default for Overwrite { + fn default() -> Self { + Self(true) } } @@ -47,9 +43,48 @@ impl TryFrom<&[u8]> for Overwrite { fn try_from(value: &[u8]) -> Result { match value { - b"T" => Ok(Overwrite::T), - b"F" => Ok(Overwrite::F), + b"T" => Ok(Self(true)), + b"F" => Ok(Self(false)), _ => Err(InvalidOverwriteHeader), } } } + +#[cfg(test)] +mod tests { + use axum::{extract::FromRequestParts, response::IntoResponse}; + use http::Request; + + use crate::header::Overwrite; + + #[tokio::test] + async fn test_overwrite_default() { + let request = Request::put("asd").body(()).unwrap(); + let (mut parts, _) = request.into_parts(); + let overwrite = Overwrite::from_request_parts(&mut parts, &()) + .await + .unwrap(); + assert_eq!( + Overwrite(true), + overwrite, + "By default we want to overwrite!" + ); + } + + #[test] + fn test_overwrite() { + assert_eq!( + Overwrite(true), + Overwrite::try_from(b"T".as_slice()).unwrap() + ); + assert_eq!( + Overwrite(false), + Overwrite::try_from(b"F".as_slice()).unwrap() + ); + if let Err(err) = Overwrite::try_from(b"aslkdjlad".as_slice()) { + let _ = err.into_response(); + } else { + unreachable!("should return error") + } + } +} diff --git a/crates/dav/src/resource/methods/copy.rs b/crates/dav/src/resource/methods/copy.rs index eb1b966..daa11a8 100644 --- a/crates/dav/src/resource/methods/copy.rs +++ b/crates/dav/src/resource/methods/copy.rs @@ -17,7 +17,7 @@ pub(crate) async fn axum_route_copy( State(resource_service): State, depth: Option, principal: R::Principal, - overwrite: Overwrite, + Overwrite(overwrite): Overwrite, matched_path: MatchedPath, header_map: HeaderMap, ) -> Result { @@ -39,7 +39,7 @@ pub(crate) async fn axum_route_copy( .map_err(|_| crate::Error::Forbidden)?; if resource_service - .copy_resource(&path, &dest_path, &principal, overwrite.is_true()) + .copy_resource(&path, &dest_path, &principal, overwrite) .await? { // Overwritten diff --git a/crates/dav/src/resource/methods/mv.rs b/crates/dav/src/resource/methods/mv.rs index 28e3568..a2ce1d4 100644 --- a/crates/dav/src/resource/methods/mv.rs +++ b/crates/dav/src/resource/methods/mv.rs @@ -17,7 +17,7 @@ pub(crate) async fn axum_route_move( State(resource_service): State, depth: Option, principal: R::Principal, - overwrite: Overwrite, + Overwrite(overwrite): Overwrite, matched_path: MatchedPath, header_map: HeaderMap, ) -> Result { @@ -39,7 +39,7 @@ pub(crate) async fn axum_route_move( .map_err(|_| crate::Error::Forbidden)?; if resource_service - .copy_resource(&path, &dest_path, &principal, overwrite.is_true()) + .copy_resource(&path, &dest_path, &principal, overwrite) .await? { // Overwritten