xml: untagged enums

This commit is contained in:
Lennart
2024-12-22 18:12:15 +01:00
parent 241b356e44
commit 9813fb5f95
4 changed files with 204 additions and 52 deletions

View File

@@ -13,11 +13,12 @@ pub struct TagAttrs {
pub ns: Option<LitByteStr>,
}
#[derive(Default, FromVariant, Clone)]
#[derive(Default, FromVariant)]
#[darling(attributes(xml))]
pub struct VariantAttrs {
#[darling(flatten)]
pub common: TagAttrs,
pub other: Flag,
}
#[derive(Default, FromDeriveInput, Clone)]

View File

@@ -1,33 +1,129 @@
use super::attrs::EnumAttrs;
use crate::de::attrs::VariantAttrs;
use darling::{FromDeriveInput, FromVariant};
use heck::ToKebabCase;
use quote::quote;
use syn::{DataEnum, DeriveInput, Fields, FieldsUnnamed, Variant};
use syn::{DataEnum, DeriveInput, Fields, FieldsUnnamed};
use super::attrs::EnumAttrs;
pub struct Variant {
variant: syn::Variant,
attrs: VariantAttrs,
}
impl Variant {
fn ident(&self) -> &syn::Ident {
&self.variant.ident
}
pub fn xml_name(&self) -> syn::LitByteStr {
self.attrs
.common
.rename
.to_owned()
.unwrap_or(syn::LitByteStr::new(
self.ident().to_string().to_kebab_case().as_bytes(),
self.ident().span(),
))
}
pub fn tagged_branch(&self) -> proc_macro2::TokenStream {
let ident = self.ident();
let variant_name = self.xml_name();
match (self.attrs.other.is_present(), &self.variant.fields) {
(_, Fields::Named(_)) => {
panic!(
"struct variants are not supported, please use a tuple variant with a struct"
)
}
(false, Fields::Unnamed(FieldsUnnamed { unnamed, .. })) => {
if unnamed.len() != 1 {
panic!("tuple variants should contain exactly one element");
}
let field = unnamed.iter().next().unwrap();
quote! {
#variant_name => {
let val = #field::deserialize(reader, start, empty)?;
Ok(Self::#ident(val))
}
}
}
(false, Fields::Unit) => {
quote! {
#variant_name => {
// Make sure that content is still consumed
::rustical_xml::Unit::deserialize(reader, start, empty)?;
Ok(Self::#ident)
}
}
}
(true, Fields::Unnamed(_)) => {
panic!("other for tuple enums not implemented yet")
}
(true, Fields::Unit) => {
quote! {
_ => {
// Make sure that content is still consumed
::rustical_xml::Unit::deserialize(reader, start, empty)?;
Ok(Self::#ident)
}
}
}
}
}
pub fn untagged_branch(&self) -> proc_macro2::TokenStream {
if self.attrs.other.is_present() {
panic!("using the other flag on an untagged variant is futile");
}
let ident = self.ident();
match &self.variant.fields {
Fields::Named(_) => {
panic!(
"struct variants are not supported, please use a tuple variant with a struct"
)
}
Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
if unnamed.len() != 1 {
panic!("tuple variants should contain exactly one element");
}
let field = unnamed.iter().next().unwrap();
quote! {
if let Ok(val) = #field::deserialize(reader, start, empty) {
return Ok(Self::#ident(val));
}
}
}
Fields::Unit => {
quote! {
// Make sure that content is still consumed
if let Ok(_) = ::rustical_xml::Unit::deserialize(reader, start, empty) {
return Ok(Self::#ident);
}
}
}
}
}
}
pub struct Enum {
attrs: EnumAttrs,
variants: Vec<syn::Variant>,
variants: Vec<Variant>,
ident: syn::Ident,
generics: syn::Generics,
}
impl Enum {
pub fn impl_de(&self) -> proc_macro2::TokenStream {
fn impl_de_untagged(&self) -> proc_macro2::TokenStream {
let (impl_generics, type_generics, where_clause) = self.generics.split_for_impl();
let name = &self.ident;
let variants = self.variants.iter().map(|variant| {
let attrs = VariantAttrs::from_variant(variant).unwrap();
let variant_name = attrs.common.rename.unwrap_or(syn::LitByteStr::new(
variant.ident.to_string().to_kebab_case().as_bytes(),
variant.ident.span(),
));
let branch = enum_variant_branch(variant);
quote! { #variant_name => { #branch } }
});
let variant_branches = self
.variants
.iter()
.map(|variant| variant.untagged_branch());
quote! {
impl #impl_generics ::rustical_xml::XmlDeserialize for #name #type_generics #where_clause {
@@ -36,12 +132,31 @@ impl Enum {
start: &quick_xml::events::BytesStart,
empty: bool
) -> Result<Self, rustical_xml::XmlDeError> {
use quick_xml::events::Event;
#(#variant_branches);*
Err(rustical_xml::XmlDeError::UnknownError)
}
}
}
}
fn impl_de_tagged(&self) -> proc_macro2::TokenStream {
let (impl_generics, type_generics, where_clause) = self.generics.split_for_impl();
let name = &self.ident;
let variant_branches = self.variants.iter().map(Variant::tagged_branch);
quote! {
impl #impl_generics ::rustical_xml::XmlDeserialize for #name #type_generics #where_clause {
fn deserialize<R: std::io::BufRead>(
reader: &mut quick_xml::NsReader<R>,
start: &quick_xml::events::BytesStart,
empty: bool
) -> Result<Self, rustical_xml::XmlDeError> {
let (_ns, name) = reader.resolve_element(start.name());
match name.as_ref() {
#(#variants),*
#(#variant_branches),*
name => {
// Handle invalid variant name
Err(rustical_xml::XmlDeError::InvalidVariant(String::from_utf8_lossy(name).to_string()))
@@ -52,41 +167,28 @@ impl Enum {
}
}
pub fn impl_de(&self) -> proc_macro2::TokenStream {
match self.attrs.untagged.is_present() {
true => self.impl_de_untagged(),
false => self.impl_de_tagged(),
}
}
pub fn parse(input: &DeriveInput, data: &DataEnum) -> Self {
let attrs = EnumAttrs::from_derive_input(input).unwrap();
Self {
variants: data
.variants
.iter()
.map(|variant| Variant {
attrs: VariantAttrs::from_variant(variant).unwrap(),
variant: variant.to_owned(),
})
.collect(),
attrs,
variants: data.variants.iter().cloned().collect(),
ident: input.ident.to_owned(),
generics: input.generics.to_owned(),
}
}
}
pub fn enum_variant_branch(variant: &Variant) -> proc_macro2::TokenStream {
let ident = &variant.ident;
match &variant.fields {
Fields::Named(_) => {
panic!("struct variants are not supported, please use a tuple variant with a struct")
}
Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
if unnamed.len() != 1 {
panic!("tuple variants should contain exactly one element");
}
let field = unnamed.iter().next().unwrap();
quote! {
let val = #field::deserialize(reader, start, empty)?;
Ok(Self::#ident(val))
}
}
Fields::Unit => {
quote! {
// Make sure that content is still consumed
::rustical_xml::Unit::deserialize(reader, start, empty)?;
Ok(Self::#ident)
}
}
}
}

View File

@@ -1,5 +1,3 @@
use crate::de::field;
use super::attrs::{ContainerAttrs, FieldAttrs, FieldType};
use darling::FromField;
use heck::ToKebabCase;

View File

@@ -1,8 +1,8 @@
use rustical_xml::{de::XmlRootParseStr, XmlDeserialize, XmlRoot};
use rustical_xml::{de::XmlRootParseStr, Unit, XmlDeserialize, XmlRoot};
use std::io::BufRead;
#[test]
fn test_struct_untagged_enum() {
fn test_struct_tagged_enum() {
#[derive(Debug, XmlDeserialize, XmlRoot, PartialEq)]
#[xml(root = b"propfind")]
struct Propfind {
@@ -11,14 +11,15 @@ fn test_struct_untagged_enum() {
#[derive(Debug, XmlDeserialize, PartialEq)]
struct Prop {
#[xml(ty = "untagged")]
prop: PropEnum,
#[xml(ty = "untagged", flatten)]
prop: Vec<PropEnum>,
}
#[derive(Debug, XmlDeserialize, PartialEq)]
enum PropEnum {
A,
B,
Displayname(String),
}
let doc = Propfind::parse_str(
@@ -26,6 +27,8 @@ fn test_struct_untagged_enum() {
<propfind>
<prop>
<b/>
<a/>
<displayname>Hello!</displayname>
</prop>
</propfind>
"#,
@@ -34,7 +37,55 @@ fn test_struct_untagged_enum() {
assert_eq!(
doc,
Propfind {
prop: Prop { prop: PropEnum::B }
prop: Prop {
prop: vec![
PropEnum::B,
PropEnum::A,
PropEnum::Displayname("Hello!".to_owned())
]
}
}
);
}
#[test]
fn test_tagged_enum_complex() {
#[derive(Debug, XmlDeserialize, XmlRoot, PartialEq)]
#[xml(root = b"propfind")]
struct Propfind {
prop: PropStruct,
}
#[derive(Debug, XmlDeserialize, PartialEq)]
struct PropStruct {
#[xml(ty = "untagged", flatten)]
prop: Vec<Prop>,
}
#[derive(Debug, XmlDeserialize, PartialEq)]
enum Prop {
Nice(Nice),
#[xml(other)]
Invalid,
}
#[derive(Debug, XmlDeserialize, PartialEq)]
struct Nice {
nice: Unit,
}
let asd = Propfind::parse_str(
r#"
<propfind>
<prop>
<nice>
<nice />
</nice>
<wtf />
</prop>
</propfind>
"#,
)
.unwrap();
dbg!(asd);
}