xml: EnumUnitVariants support untagged enums

This commit is contained in:
Lennart
2025-01-18 21:51:30 +01:00
parent e9610dc974
commit 8d7574290c
3 changed files with 179 additions and 65 deletions

View File

@@ -240,87 +240,164 @@ impl Enum {
} }
pub fn impl_enum_unit_variants(&self) -> proc_macro2::TokenStream { pub fn impl_enum_unit_variants(&self) -> proc_macro2::TokenStream {
let ident = &self.ident;
if self.attrs.untagged.is_present() {
panic!("EnumUnitVariants not implemented for untagged enums");
}
let unit_enum_ident = self let unit_enum_ident = self
.attrs .attrs
.unit_variants_ident .unit_variants_ident
.as_ref() .as_ref()
.expect("unit_variants_ident no set"); .expect("unit_variants_ident no set");
let ident = &self.ident;
let tagged_variants: Vec<_> = self if self.attrs.untagged.is_present() {
.variants let variant_branches: Vec<_> = self
.iter() .variants
.filter(|variant| !variant.attrs.other.is_present()) .iter()
.collect(); .map(|variant| {
let variant_type = variant.deserializer_type();
let variant_ident = &variant.variant.ident;
quote! {
#variant_ident (<#variant_type as ::rustical_xml::EnumUnitVariants>::UnitVariants)
}
})
.collect();
let variant_outputs: Vec<_> = tagged_variants let variant_idents: Vec<_> = self
.iter() .variants
.map(|variant| { .iter()
let ns = match &variant.attrs.common.ns { .map(|variant| &variant.variant.ident)
Some(ns) => quote! { Some(#ns) }, .collect();
None => quote! { None },
}; let unit_to_output_branches = variant_idents.iter().map(|variant_ident| {
quote! { #unit_enum_ident::#variant_ident(val) => val.into() }
});
let str_to_unit_branches = self.variants.iter().map(|variant| {
let variant_type = variant.deserializer_type();
let variant_ident = &variant.variant.ident;
quote! {
if let Ok(name) = <#variant_type as ::rustical_xml::EnumUnitVariants>::UnitVariants::from_str(val) {
return Ok(Self::#variant_ident(name))
}
}
});
let from_enum_to_unit_branches = variant_idents.iter().map(|variant_ident| {
quote! { #ident::#variant_ident(val) => #unit_enum_ident::#variant_ident(val.into()) }
});
quote! {
#[derive(Clone, Debug, PartialEq)]
pub enum #unit_enum_ident {
#(#variant_branches),*
}
impl ::rustical_xml::EnumUnitVariants for #ident {
type UnitVariants = #unit_enum_ident;
}
impl From<#unit_enum_ident> for (Option<::quick_xml::name::Namespace<'static>>, &'static str) {
fn from(val: #unit_enum_ident) -> Self {
match val {
#(#unit_to_output_branches),*
}
}
}
impl From<#ident> for #unit_enum_ident {
fn from(val: #ident) -> Self {
match val {
#(#from_enum_to_unit_branches),*
}
}
}
impl ::std::str::FromStr for #unit_enum_ident {
type Err = ::rustical_xml::FromStrError;
fn from_str(val: &str) -> Result<Self, Self::Err> {
#(#str_to_unit_branches);*
Err(::rustical_xml::FromStrError)
}
}
}
} else {
let tagged_variants: Vec<_> = self
.variants
.iter()
.filter(|variant| !variant.attrs.other.is_present())
.collect();
let variant_outputs: Vec<_> = tagged_variants
.iter()
.map(|variant| {
let ns = match &variant.attrs.common.ns {
Some(ns) => quote! { Some(#ns) },
None => quote! { None },
};
let b_xml_name = variant.xml_name().value();
let xml_name = String::from_utf8_lossy(&b_xml_name);
quote! {(#ns, #xml_name)}
})
.collect();
let variant_idents: Vec<_> = tagged_variants
.iter()
.map(|variant| &variant.variant.ident)
.collect();
let unit_to_output_branches =
variant_idents
.iter()
.zip(&variant_outputs)
.map(|(variant_ident, out)| {
quote! { #unit_enum_ident::#variant_ident => #out }
});
let from_enum_to_unit_branches = variant_idents.iter().map(|variant_ident| {
quote! { #ident::#variant_ident { .. } => #unit_enum_ident::#variant_ident }
});
let str_to_unit_branches = tagged_variants.iter().map(|variant| {
let variant_ident = &variant.variant.ident;
let b_xml_name = variant.xml_name().value(); let b_xml_name = variant.xml_name().value();
let xml_name = String::from_utf8_lossy(&b_xml_name); let xml_name = String::from_utf8_lossy(&b_xml_name);
quote! {(#ns, #xml_name)} quote! { #xml_name => Ok(#unit_enum_ident::#variant_ident) }
}) });
.collect();
let variant_idents: Vec<_> = tagged_variants quote! {
.iter() #[derive(Clone, Debug, PartialEq)]
.map(|variant| &variant.variant.ident) pub enum #unit_enum_ident {
.collect(); #(#variant_idents),*
}
let unit_to_output_branches =
variant_idents
.iter()
.zip(&variant_outputs)
.map(|(variant_ident, out)| {
quote! { #unit_enum_ident::#variant_ident => #out }
});
let from_enum_to_unit_branches = variant_idents.iter().map(|variant_ident| { impl ::rustical_xml::EnumUnitVariants for #ident {
quote! { #ident::#variant_ident { .. } => #unit_enum_ident::#variant_ident } type UnitVariants = #unit_enum_ident;
}); }
let str_to_unit_branches = tagged_variants.iter().map(|variant| { impl From<#unit_enum_ident> for (Option<::quick_xml::name::Namespace<'static>>, &'static str) {
let variant_ident = &variant.variant.ident; fn from(val: #unit_enum_ident) -> Self {
let b_xml_name = variant.xml_name().value(); match val {
let xml_name = String::from_utf8_lossy(&b_xml_name); #(#unit_to_output_branches),*
quote! { #xml_name => Ok(#unit_enum_ident::#variant_ident) } }
});
quote! {
pub enum #unit_enum_ident {
#(#variant_idents),*
}
impl From<#unit_enum_ident> for (Option<::quick_xml::name::Namespace<'static>>, &'static str) {
fn from(val: #unit_enum_ident) -> Self {
match val {
#(#unit_to_output_branches),*
} }
} }
}
impl From<#ident> for #unit_enum_ident { impl From<#ident> for #unit_enum_ident {
fn from(val: #ident) -> Self { fn from(val: #ident) -> Self {
match val { match val {
#(#from_enum_to_unit_branches),* #(#from_enum_to_unit_branches),*
}
} }
} }
}
impl ::std::str::FromStr for #unit_enum_ident { impl ::std::str::FromStr for #unit_enum_ident {
type Err = ::rustical_xml::FromStrError; type Err = ::rustical_xml::FromStrError;
fn from_str(val: &str) -> Result<Self, Self::Err> { fn from_str(val: &str) -> Result<Self, Self::Err> {
match val { match val {
#(#str_to_unit_branches),*, #(#str_to_unit_branches),*,
_ => Err(::rustical_xml::FromStrError) _ => Err(::rustical_xml::FromStrError)
}
} }
} }
} }

View File

@@ -33,3 +33,7 @@ pub trait EnumVariants {
// Returns all valid xml names including untagged variants // Returns all valid xml names including untagged variants
fn variant_names() -> Vec<(Option<Namespace<'static>>, &'static str)>; fn variant_names() -> Vec<(Option<Namespace<'static>>, &'static str)>;
} }
pub trait EnumUnitVariants {
type UnitVariants;
}

View File

@@ -12,7 +12,8 @@ pub const NS_ICAL: Namespace = Namespace(b"http://apple.com/ns/ical/");
pub const NS_CALENDARSERVER: Namespace = Namespace(b"http://calendarserver.org/ns/"); pub const NS_CALENDARSERVER: Namespace = Namespace(b"http://calendarserver.org/ns/");
pub const NS_NEXTCLOUD: Namespace = Namespace(b"http://nextcloud.com/ns"); pub const NS_NEXTCLOUD: Namespace = Namespace(b"http://nextcloud.com/ns");
#[derive(EnumVariants)] #[derive(EnumVariants, EnumUnitVariants)]
#[xml(unit_variants_ident = "ExtensionsPropName")]
enum ExtensionProp { enum ExtensionProp {
Hello, Hello,
} }
@@ -44,8 +45,8 @@ fn test_enum_tagged_variants() {
); );
} }
#[derive(EnumVariants)] #[derive(EnumVariants, EnumUnitVariants)]
#[xml(untagged)] #[xml(untagged, unit_variants_ident = "UnionPropName")]
enum UnionProp { enum UnionProp {
Calendar(CalendarProp), Calendar(CalendarProp),
Extension(ExtensionProp), Extension(ExtensionProp),
@@ -77,5 +78,37 @@ fn test_enum_unit_variants() {
assert_eq!(displayname, (Some(NS_DAV), "displayname")); assert_eq!(displayname, (Some(NS_DAV), "displayname"));
let propname: CalendarPropName = FromStr::from_str("displayname").unwrap(); let propname: CalendarPropName = FromStr::from_str("displayname").unwrap();
assert_eq!(displayname, (Some(NS_DAV), "displayname")); assert_eq!(propname, CalendarPropName::Displayname)
}
#[test]
fn test_enum_unit_variants_untagged() {
let displayname: (Option<Namespace>, &str) =
UnionPropName::Calendar(CalendarPropName::Displayname).into();
assert_eq!(displayname, (Some(NS_DAV), "displayname"));
let hello: (Option<Namespace>, &str) =
UnionPropName::Extension(ExtensionsPropName::Hello).into();
assert_eq!(hello, (None, "hello"));
let propname: UnionPropName = FromStr::from_str("displayname").unwrap();
assert_eq!(
propname,
UnionPropName::Calendar(CalendarPropName::Displayname)
);
let propname: UnionPropName = FromStr::from_str("hello").unwrap();
assert_eq!(
propname,
UnionPropName::Extension(ExtensionsPropName::Hello)
);
let propname: UnionPropName = UnionProp::Calendar(CalendarProp::Displayname(None)).into();
assert_eq!(
propname,
UnionPropName::Calendar(CalendarPropName::Displayname)
);
let propname: UnionPropName = UnionProp::Extension(ExtensionProp::Hello).into();
assert_eq!(
propname,
UnionPropName::Extension(ExtensionsPropName::Hello)
);
} }