diff --git a/crates/xml/derive/src/xml_enum.rs b/crates/xml/derive/src/xml_enum.rs index 0876e9b..3d2e218 100644 --- a/crates/xml/derive/src/xml_enum.rs +++ b/crates/xml/derive/src/xml_enum.rs @@ -240,87 +240,164 @@ impl Enum { } pub fn impl_enum_unit_variants(&self) -> proc_macro2::TokenStream { - let ident = &self.ident; - if self.attrs.untagged.is_present() { - panic!("EnumUnitVariants not implemented for untagged enums"); - } let unit_enum_ident = self .attrs .unit_variants_ident .as_ref() .expect("unit_variants_ident no set"); + let ident = &self.ident; - let tagged_variants: Vec<_> = self - .variants - .iter() - .filter(|variant| !variant.attrs.other.is_present()) - .collect(); + if self.attrs.untagged.is_present() { + let variant_branches: Vec<_> = self + .variants + .iter() + .map(|variant| { + let variant_type = variant.deserializer_type(); + let variant_ident = &variant.variant.ident; + quote! { + #variant_ident (<#variant_type as ::rustical_xml::EnumUnitVariants>::UnitVariants) + } + }) + .collect(); - let variant_outputs: Vec<_> = tagged_variants - .iter() - .map(|variant| { - let ns = match &variant.attrs.common.ns { - Some(ns) => quote! { Some(#ns) }, - None => quote! { None }, - }; + let variant_idents: Vec<_> = self + .variants + .iter() + .map(|variant| &variant.variant.ident) + .collect(); + + let unit_to_output_branches = variant_idents.iter().map(|variant_ident| { + quote! { #unit_enum_ident::#variant_ident(val) => val.into() } + }); + + let str_to_unit_branches = self.variants.iter().map(|variant| { + let variant_type = variant.deserializer_type(); + let variant_ident = &variant.variant.ident; + quote! { + if let Ok(name) = <#variant_type as ::rustical_xml::EnumUnitVariants>::UnitVariants::from_str(val) { + return Ok(Self::#variant_ident(name)) + } + } + }); + + let from_enum_to_unit_branches = variant_idents.iter().map(|variant_ident| { + quote! { #ident::#variant_ident(val) => #unit_enum_ident::#variant_ident(val.into()) } + }); + + quote! { + #[derive(Clone, Debug, PartialEq)] + pub enum #unit_enum_ident { + #(#variant_branches),* + } + + impl ::rustical_xml::EnumUnitVariants for #ident { + type UnitVariants = #unit_enum_ident; + } + + impl From<#unit_enum_ident> for (Option<::quick_xml::name::Namespace<'static>>, &'static str) { + fn from(val: #unit_enum_ident) -> Self { + match val { + #(#unit_to_output_branches),* + } + } + } + + impl From<#ident> for #unit_enum_ident { + fn from(val: #ident) -> Self { + match val { + #(#from_enum_to_unit_branches),* + } + } + } + + impl ::std::str::FromStr for #unit_enum_ident { + type Err = ::rustical_xml::FromStrError; + + fn from_str(val: &str) -> Result { + #(#str_to_unit_branches);* + Err(::rustical_xml::FromStrError) + } + } + } + } else { + let tagged_variants: Vec<_> = self + .variants + .iter() + .filter(|variant| !variant.attrs.other.is_present()) + .collect(); + + let variant_outputs: Vec<_> = tagged_variants + .iter() + .map(|variant| { + let ns = match &variant.attrs.common.ns { + Some(ns) => quote! { Some(#ns) }, + None => quote! { None }, + }; + let b_xml_name = variant.xml_name().value(); + let xml_name = String::from_utf8_lossy(&b_xml_name); + quote! {(#ns, #xml_name)} + }) + .collect(); + + let variant_idents: Vec<_> = tagged_variants + .iter() + .map(|variant| &variant.variant.ident) + .collect(); + + let unit_to_output_branches = + variant_idents + .iter() + .zip(&variant_outputs) + .map(|(variant_ident, out)| { + quote! { #unit_enum_ident::#variant_ident => #out } + }); + + let from_enum_to_unit_branches = variant_idents.iter().map(|variant_ident| { + quote! { #ident::#variant_ident { .. } => #unit_enum_ident::#variant_ident } + }); + + let str_to_unit_branches = tagged_variants.iter().map(|variant| { + let variant_ident = &variant.variant.ident; let b_xml_name = variant.xml_name().value(); let xml_name = String::from_utf8_lossy(&b_xml_name); - quote! {(#ns, #xml_name)} - }) - .collect(); + quote! { #xml_name => Ok(#unit_enum_ident::#variant_ident) } + }); - let variant_idents: Vec<_> = tagged_variants - .iter() - .map(|variant| &variant.variant.ident) - .collect(); + quote! { + #[derive(Clone, Debug, PartialEq)] + pub enum #unit_enum_ident { + #(#variant_idents),* + } - let unit_to_output_branches = - variant_idents - .iter() - .zip(&variant_outputs) - .map(|(variant_ident, out)| { - quote! { #unit_enum_ident::#variant_ident => #out } - }); - let from_enum_to_unit_branches = variant_idents.iter().map(|variant_ident| { - quote! { #ident::#variant_ident { .. } => #unit_enum_ident::#variant_ident } - }); + impl ::rustical_xml::EnumUnitVariants for #ident { + type UnitVariants = #unit_enum_ident; + } - let str_to_unit_branches = tagged_variants.iter().map(|variant| { - let variant_ident = &variant.variant.ident; - let b_xml_name = variant.xml_name().value(); - let xml_name = String::from_utf8_lossy(&b_xml_name); - quote! { #xml_name => Ok(#unit_enum_ident::#variant_ident) } - }); - - quote! { - pub enum #unit_enum_ident { - #(#variant_idents),* - } - - impl From<#unit_enum_ident> for (Option<::quick_xml::name::Namespace<'static>>, &'static str) { - fn from(val: #unit_enum_ident) -> Self { - match val { - #(#unit_to_output_branches),* + impl From<#unit_enum_ident> for (Option<::quick_xml::name::Namespace<'static>>, &'static str) { + fn from(val: #unit_enum_ident) -> Self { + match val { + #(#unit_to_output_branches),* + } } } - } - impl From<#ident> for #unit_enum_ident { - fn from(val: #ident) -> Self { - match val { - #(#from_enum_to_unit_branches),* + impl From<#ident> for #unit_enum_ident { + fn from(val: #ident) -> Self { + match val { + #(#from_enum_to_unit_branches),* + } } } - } - impl ::std::str::FromStr for #unit_enum_ident { - type Err = ::rustical_xml::FromStrError; + impl ::std::str::FromStr for #unit_enum_ident { + type Err = ::rustical_xml::FromStrError; - fn from_str(val: &str) -> Result { - match val { - #(#str_to_unit_branches),*, - _ => Err(::rustical_xml::FromStrError) + fn from_str(val: &str) -> Result { + match val { + #(#str_to_unit_branches),*, + _ => Err(::rustical_xml::FromStrError) + } } } } diff --git a/crates/xml/src/lib.rs b/crates/xml/src/lib.rs index a46e9b3..9eca8cf 100644 --- a/crates/xml/src/lib.rs +++ b/crates/xml/src/lib.rs @@ -33,3 +33,7 @@ pub trait EnumVariants { // Returns all valid xml names including untagged variants fn variant_names() -> Vec<(Option>, &'static str)>; } + +pub trait EnumUnitVariants { + type UnitVariants; +} diff --git a/crates/xml/tests/enum_variants.rs b/crates/xml/tests/enum_variants.rs index 83d05ae..5518bf0 100644 --- a/crates/xml/tests/enum_variants.rs +++ b/crates/xml/tests/enum_variants.rs @@ -12,7 +12,8 @@ pub const NS_ICAL: Namespace = Namespace(b"http://apple.com/ns/ical/"); pub const NS_CALENDARSERVER: Namespace = Namespace(b"http://calendarserver.org/ns/"); pub const NS_NEXTCLOUD: Namespace = Namespace(b"http://nextcloud.com/ns"); -#[derive(EnumVariants)] +#[derive(EnumVariants, EnumUnitVariants)] +#[xml(unit_variants_ident = "ExtensionsPropName")] enum ExtensionProp { Hello, } @@ -44,8 +45,8 @@ fn test_enum_tagged_variants() { ); } -#[derive(EnumVariants)] -#[xml(untagged)] +#[derive(EnumVariants, EnumUnitVariants)] +#[xml(untagged, unit_variants_ident = "UnionPropName")] enum UnionProp { Calendar(CalendarProp), Extension(ExtensionProp), @@ -77,5 +78,37 @@ fn test_enum_unit_variants() { assert_eq!(displayname, (Some(NS_DAV), "displayname")); let propname: CalendarPropName = FromStr::from_str("displayname").unwrap(); - assert_eq!(displayname, (Some(NS_DAV), "displayname")); + assert_eq!(propname, CalendarPropName::Displayname) +} + +#[test] +fn test_enum_unit_variants_untagged() { + let displayname: (Option, &str) = + UnionPropName::Calendar(CalendarPropName::Displayname).into(); + assert_eq!(displayname, (Some(NS_DAV), "displayname")); + let hello: (Option, &str) = + UnionPropName::Extension(ExtensionsPropName::Hello).into(); + assert_eq!(hello, (None, "hello")); + + let propname: UnionPropName = FromStr::from_str("displayname").unwrap(); + assert_eq!( + propname, + UnionPropName::Calendar(CalendarPropName::Displayname) + ); + let propname: UnionPropName = FromStr::from_str("hello").unwrap(); + assert_eq!( + propname, + UnionPropName::Extension(ExtensionsPropName::Hello) + ); + + let propname: UnionPropName = UnionProp::Calendar(CalendarProp::Displayname(None)).into(); + assert_eq!( + propname, + UnionPropName::Calendar(CalendarPropName::Displayname) + ); + let propname: UnionPropName = UnionProp::Extension(ExtensionProp::Hello).into(); + assert_eq!( + propname, + UnionPropName::Extension(ExtensionsPropName::Hello) + ); }