xml: some refactoring

This commit is contained in:
Lennart
2024-12-22 15:18:43 +01:00
parent cd4137cda1
commit 241b356e44
5 changed files with 109 additions and 82 deletions

View File

@@ -1,9 +1,69 @@
use crate::de::attrs::VariantAttrs; use crate::de::attrs::VariantAttrs;
use darling::FromVariant; use darling::{FromDeriveInput, FromVariant};
use heck::ToKebabCase; use heck::ToKebabCase;
use quote::quote; use quote::quote;
use syn::{DataEnum, DeriveInput, Fields, FieldsUnnamed, Variant}; use syn::{DataEnum, DeriveInput, Fields, FieldsUnnamed, Variant};
use super::attrs::EnumAttrs;
pub struct Enum {
attrs: EnumAttrs,
variants: Vec<syn::Variant>,
ident: syn::Ident,
generics: syn::Generics,
}
impl Enum {
pub fn impl_de(&self) -> proc_macro2::TokenStream {
let (impl_generics, type_generics, where_clause) = self.generics.split_for_impl();
let name = &self.ident;
let variants = self.variants.iter().map(|variant| {
let attrs = VariantAttrs::from_variant(variant).unwrap();
let variant_name = attrs.common.rename.unwrap_or(syn::LitByteStr::new(
variant.ident.to_string().to_kebab_case().as_bytes(),
variant.ident.span(),
));
let branch = enum_variant_branch(variant);
quote! { #variant_name => { #branch } }
});
quote! {
impl #impl_generics ::rustical_xml::XmlDeserialize for #name #type_generics #where_clause {
fn deserialize<R: std::io::BufRead>(
reader: &mut quick_xml::NsReader<R>,
start: &quick_xml::events::BytesStart,
empty: bool
) -> Result<Self, rustical_xml::XmlDeError> {
use quick_xml::events::Event;
let (_ns, name) = reader.resolve_element(start.name());
match name.as_ref() {
#(#variants),*
name => {
// Handle invalid variant name
Err(rustical_xml::XmlDeError::InvalidVariant(String::from_utf8_lossy(name).to_string()))
}
}
}
}
}
}
pub fn parse(input: &DeriveInput, data: &DataEnum) -> Self {
let attrs = EnumAttrs::from_derive_input(input).unwrap();
Self {
attrs,
variants: data.variants.iter().cloned().collect(),
ident: input.ident.to_owned(),
generics: input.generics.to_owned(),
}
}
}
pub fn enum_variant_branch(variant: &Variant) -> proc_macro2::TokenStream { pub fn enum_variant_branch(variant: &Variant) -> proc_macro2::TokenStream {
let ident = &variant.ident; let ident = &variant.ident;
@@ -30,46 +90,3 @@ pub fn enum_variant_branch(variant: &Variant) -> proc_macro2::TokenStream {
} }
} }
} }
pub fn impl_de_enum(input: &DeriveInput, data: &DataEnum) -> proc_macro2::TokenStream {
let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
let name = &input.ident;
let variants = data.variants.iter().map(|variant| {
let attrs = VariantAttrs::from_variant(variant).unwrap();
let variant_name = attrs.common.rename.unwrap_or(syn::LitByteStr::new(
variant.ident.to_string().to_kebab_case().as_bytes(),
variant.ident.span(),
));
let branch = enum_variant_branch(variant);
dbg!(&variant_name);
quote! {
#variant_name => {
#branch
}
}
});
quote! {
impl #impl_generics ::rustical_xml::XmlDeserialize for #name #type_generics #where_clause {
fn deserialize<R: std::io::BufRead>(
reader: &mut quick_xml::NsReader<R>,
start: &quick_xml::events::BytesStart,
empty: bool
) -> Result<Self, rustical_xml::XmlDeError> {
use quick_xml::events::Event;
let (_ns, name) = reader.resolve_element(start.name());
match name.as_ref() {
#(#variants)*
name => {
// Handle invalid variant name
Err(rustical_xml::XmlDeError::InvalidVariant(String::from_utf8_lossy(name).to_string()))
}
}
}
}
}
}

View File

@@ -94,17 +94,17 @@ impl NamedStruct {
// initialise fields // initialise fields
struct StructBuilder #type_generics #where_clause { struct StructBuilder #type_generics #where_clause {
#(#builder_fields)* #(#builder_fields),*
} }
let mut builder = StructBuilder { let mut builder = StructBuilder {
#(#builder_field_inits)* #(#builder_field_inits),*
}; };
for attr in start.attributes() { for attr in start.attributes() {
let attr = attr?; let attr = attr?;
match attr.key.as_ref() { match attr.key.as_ref() {
#(#attr_field_branches)* #(#attr_field_branches),*
_ => { #invalid_field_branch } _ => { #invalid_field_branch }
} }
} }
@@ -122,8 +122,8 @@ impl NamedStruct {
let empty = matches!(event, Event::Empty(_)); let empty = matches!(event, Event::Empty(_));
let (ns, name) = reader.resolve_element(start.name()); let (ns, name) = reader.resolve_element(start.name());
match (ns, name.as_ref()) { match (ns, name.as_ref()) {
#(#named_field_branches)* #(#named_field_branches),*
#(#untagged_field_branches)* #(#untagged_field_branches),*
_ => { #invalid_field_branch } _ => { #invalid_field_branch }
} }
} }
@@ -156,7 +156,7 @@ impl NamedStruct {
} }
Ok(Self { Ok(Self {
#(#builder_field_builds)* #(#builder_field_builds),*
}) })
} }
} }

View File

@@ -1,3 +1,5 @@
use crate::de::field;
use super::attrs::{ContainerAttrs, FieldAttrs, FieldType}; use super::attrs::{ContainerAttrs, FieldAttrs, FieldType};
use darling::FromField; use darling::FromField;
use heck::ToKebabCase; use heck::ToKebabCase;
@@ -45,7 +47,9 @@ impl Field {
container_attrs, container_attrs,
} }
} }
pub fn de_name(&self) -> syn::LitByteStr {
/// Field name in XML
pub fn xml_name(&self) -> syn::LitByteStr {
self.attrs self.attrs
.common .common
.rename .rename
@@ -56,10 +60,12 @@ impl Field {
)) ))
} }
/// Whether to enforce the correct XML namespace
pub fn ns_strict(&self) -> bool { pub fn ns_strict(&self) -> bool {
self.attrs.common.ns_strict.is_present() || self.container_attrs.ns_strict.is_present() self.attrs.common.ns_strict.is_present() || self.container_attrs.ns_strict.is_present()
} }
/// Field identifier
pub fn field_ident(&self) -> &syn::Ident { pub fn field_ident(&self) -> &syn::Ident {
self.field self.field
.ident .ident
@@ -67,48 +73,52 @@ impl Field {
.expect("tuple structs not supported") .expect("tuple structs not supported")
} }
/// Field type
pub fn ty(&self) -> &syn::Type { pub fn ty(&self) -> &syn::Type {
&self.field.ty &self.field.ty
} }
/// Field in the builder struct for the deserializer
pub fn builder_field(&self) -> proc_macro2::TokenStream { pub fn builder_field(&self) -> proc_macro2::TokenStream {
let field_ident = self.field_ident(); let field_ident = self.field_ident();
let ty = self.ty(); let ty = self.ty();
match (self.attrs.flatten.is_present(), &self.attrs.default) {
(_, Some(_default)) => quote! { #field_ident: #ty, }, let builder_field_type = match (self.attrs.flatten.is_present(), &self.attrs.default) {
(_, Some(_default)) => quote! { #ty },
(true, None) => { (true, None) => {
let generic_type = get_generic_type(ty).expect("flatten attribute only implemented for explicit generics (rustical_xml will assume the first generic as the inner type)"); let generic_type = get_generic_type(ty).expect("flatten attribute only implemented for explicit generics (rustical_xml will assume the first generic as the inner type)");
quote! { #field_ident: Vec<#generic_type>, } quote! { Vec<#generic_type> }
} }
(false, None) => quote! { #field_ident: Option<#ty>, }, (false, None) => quote! { Option<#ty> },
} };
quote! { #field_ident: #builder_field_type }
} }
/// Field initialiser in the builder struct for the deserializer
pub fn builder_field_init(&self) -> proc_macro2::TokenStream { pub fn builder_field_init(&self) -> proc_macro2::TokenStream {
let field_ident = self.field_ident(); let field_ident = self.field_ident();
match (self.attrs.flatten.is_present(), &self.attrs.default) { let builder_field_initialiser = match (self.attrs.flatten.is_present(), &self.attrs.default)
(_, Some(default)) => quote! { #field_ident: #default(), }, {
(true, None) => quote! { #field_ident: vec![], }, (_, Some(default)) => quote! { #default() },
(false, None) => quote! { #field_ident: None, }, (true, None) => quote! { vec![] },
} (false, None) => quote! { None },
};
quote! { #field_ident: #builder_field_initialiser }
} }
/// Map builder field to target field
pub fn builder_field_build(&self) -> proc_macro2::TokenStream { pub fn builder_field_build(&self) -> proc_macro2::TokenStream {
let field_ident = self.field_ident(); let field_ident = self.field_ident();
match ( let builder_value = match (
self.attrs.flatten.is_present(), self.attrs.flatten.is_present(),
self.attrs.default.is_some(), self.attrs.default.is_some(),
) { ) {
(true, _) => quote! { (true, _) => quote! { FromIterator::from_iter(builder.#field_ident.into_iter()) },
#field_ident: FromIterator::from_iter(builder.#field_ident.into_iter()) (false, true) => quote! { builder.#field_ident },
}, (false, false) => quote! { builder.#field_ident.expect("todo: handle missing field") },
(false, true) => quote! { };
#field_ident: builder.#field_ident, quote! { #field_ident: #builder_value }
},
(false, false) => quote! {
#field_ident: builder.#field_ident.expect("todo: handle missing field"),
},
}
} }
pub fn named_branch(&self) -> Option<proc_macro2::TokenStream> { pub fn named_branch(&self) -> Option<proc_macro2::TokenStream> {
@@ -126,7 +136,7 @@ impl Field {
quote! {_} quote! {_}
}; };
let field_name = self.de_name(); let field_name = self.xml_name();
let field_ident = self.field_ident(); let field_ident = self.field_ident();
let deserializer = self.ty(); let deserializer = self.ty();
let value = quote! { <#deserializer as rustical_xml::XmlDeserialize>::deserialize(reader, &start, empty)? }; let value = quote! { <#deserializer as rustical_xml::XmlDeserialize>::deserialize(reader, &start, empty)? };
@@ -147,7 +157,7 @@ impl Field {
}; };
Some(quote! { Some(quote! {
(#namespace_match, #field_name) => { #assignment; }, (#namespace_match, #field_name) => { #assignment; }
}) })
} }
@@ -163,13 +173,13 @@ impl Field {
quote! { quote! {
_ => { _ => {
builder.#field_ident.push(<#deserializer as rustical_xml::XmlDeserialize>::deserialize(reader, &start, empty)?); builder.#field_ident.push(<#deserializer as rustical_xml::XmlDeserialize>::deserialize(reader, &start, empty)?);
}, }
} }
} else { } else {
quote! { quote! {
_ => { _ => {
builder.#field_ident = Some(<#deserializer as rustical_xml::XmlDeserialize>::deserialize(reader, &start, empty)?); builder.#field_ident = Some(<#deserializer as rustical_xml::XmlDeserialize>::deserialize(reader, &start, empty)?);
}, }
} }
}) })
} }
@@ -181,8 +191,8 @@ impl Field {
let field_ident = self.field_ident(); let field_ident = self.field_ident();
let value = wrap_option_if_no_default( let value = wrap_option_if_no_default(
quote! { quote! {
rustical_xml::Value::deserialize(text.as_ref())? rustical_xml::Value::deserialize(text.as_ref())?
}, },
self.attrs.default.is_some(), self.attrs.default.is_some(),
); );
Some(quote! { Some(quote! {
@@ -195,7 +205,7 @@ impl Field {
return None; return None;
} }
let field_ident = self.field_ident(); let field_ident = self.field_ident();
let field_name = self.de_name(); let field_name = self.xml_name();
let value = wrap_option_if_no_default( let value = wrap_option_if_no_default(
quote! { quote! {

View File

@@ -3,6 +3,6 @@ mod de_enum;
mod de_struct; mod de_struct;
mod field; mod field;
pub use de_enum::impl_de_enum; pub use de_enum::Enum;
pub use de_struct::NamedStruct; pub use de_struct::NamedStruct;
pub use field::Field; pub use field::Field;

View File

@@ -2,14 +2,14 @@ use core::panic;
use syn::{parse_macro_input, DeriveInput}; use syn::{parse_macro_input, DeriveInput};
mod de; mod de;
use de::{impl_de_enum, NamedStruct}; use de::{Enum, NamedStruct};
#[proc_macro_derive(XmlDeserialize, attributes(xml))] #[proc_macro_derive(XmlDeserialize, attributes(xml))]
pub fn derive_xml_deserialize(input: proc_macro::TokenStream) -> proc_macro::TokenStream { pub fn derive_xml_deserialize(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(input as DeriveInput); let input = parse_macro_input!(input as DeriveInput);
match &input.data { match &input.data {
syn::Data::Enum(e) => impl_de_enum(&input, e), syn::Data::Enum(e) => Enum::parse(&input, e).impl_de(),
syn::Data::Struct(s) => NamedStruct::parse(&input, s).impl_de(), syn::Data::Struct(s) => NamedStruct::parse(&input, s).impl_de(),
syn::Data::Union(_) => panic!("Union not supported"), syn::Data::Union(_) => panic!("Union not supported"),
} }