From 043ce8bcd07e5c9dbb1879cc3843e95b28c080af Mon Sep 17 00:00:00 2001 From: Lennart <18233294+lennart-k@users.noreply.github.com> Date: Sun, 22 Dec 2024 12:44:19 +0100 Subject: [PATCH] xml: Move XmlRoot implementation into dedicated derive macro --- crates/xml/derive/src/de/attrs.rs | 3 ++- crates/xml/derive/src/de/de_struct.rs | 23 +++++++++++------------ crates/xml/derive/src/lib.rs | 13 ++++++++++++- crates/xml/src/de.rs | 2 +- crates/xml/tests/de_struct.rs | 21 +++++++++++---------- 5 files changed, 37 insertions(+), 25 deletions(-) diff --git a/crates/xml/derive/src/de/attrs.rs b/crates/xml/derive/src/de/attrs.rs index 8662934..bfaa54c 100644 --- a/crates/xml/derive/src/de/attrs.rs +++ b/crates/xml/derive/src/de/attrs.rs @@ -24,7 +24,8 @@ pub struct VariantAttrs { #[darling(attributes(xml))] pub struct EnumAttrs { #[darling(flatten)] - container: ContainerAttrs, + pub container: ContainerAttrs, + pub untagged: Flag, } #[derive(Default, FromDeriveInput, Clone)] diff --git a/crates/xml/derive/src/de/de_struct.rs b/crates/xml/derive/src/de/de_struct.rs index 38a8ae4..87fbccf 100644 --- a/crates/xml/derive/src/de/de_struct.rs +++ b/crates/xml/derive/src/de/de_struct.rs @@ -56,6 +56,17 @@ pub struct NamedStruct { } impl NamedStruct { + pub fn impl_xml_root(&self) -> proc_macro2::TokenStream { + let (impl_generics, type_generics, where_clause) = self.generics.split_for_impl(); + let ident = &self.ident; + let root = self.attrs.root.as_ref().expect("No root attribute found"); + quote! { + impl #impl_generics ::rustical_xml::XmlRoot for #ident #type_generics #where_clause { + fn root_tag() -> &'static [u8] { #root } + } + } + } + pub fn impl_de(&self) -> proc_macro2::TokenStream { let (impl_generics, type_generics, where_clause) = self.generics.split_for_impl(); let ident = &self.ident; @@ -76,21 +87,9 @@ impl NamedStruct { let builder_field_builds = self.fields.iter().map(Field::builder_field_build); - let xml_root_impl = if let Some(root) = &self.attrs.root { - quote! { - impl #impl_generics ::rustical_xml::XmlRoot for #ident #type_generics #where_clause { - fn root_tag() -> &'static [u8] { #root } - } - } - } else { - quote! {} - }; - let invalid_field_branch = invalid_field_branch(self.attrs.allow_invalid.is_present()); quote! { - #xml_root_impl - impl #impl_generics ::rustical_xml::XmlDeserialize for #ident #type_generics #where_clause { fn deserialize( reader: &mut quick_xml::NsReader, diff --git a/crates/xml/derive/src/lib.rs b/crates/xml/derive/src/lib.rs index 205e4d3..937b112 100644 --- a/crates/xml/derive/src/lib.rs +++ b/crates/xml/derive/src/lib.rs @@ -2,7 +2,6 @@ use core::panic; use syn::{parse_macro_input, DeriveInput}; mod de; - use de::{impl_de_enum, NamedStruct}; #[proc_macro_derive(XmlDeserialize, attributes(xml))] @@ -16,3 +15,15 @@ pub fn derive_xml_deserialize(input: proc_macro::TokenStream) -> proc_macro::Tok } .into() } + +#[proc_macro_derive(XmlRoot, attributes(xml))] +pub fn derive_xml_root(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(input as DeriveInput); + + match &input.data { + syn::Data::Struct(s) => NamedStruct::parse(&input, s).impl_xml_root(), + syn::Data::Enum(_) => panic!("Enum not supported as root"), + syn::Data::Union(_) => panic!("Union not supported as root"), + } + .into() +} diff --git a/crates/xml/src/de.rs b/crates/xml/src/de.rs index f1b0f33..e85c320 100644 --- a/crates/xml/src/de.rs +++ b/crates/xml/src/de.rs @@ -86,4 +86,4 @@ pub trait XmlRootParseStr<'i>: XmlRoot + XmlDeserialize { } } -impl<'i, T: XmlRoot + XmlDeserialize> XmlRootParseStr<'i> for T {} +impl XmlRootParseStr<'_> for T {} diff --git a/crates/xml/tests/de_struct.rs b/crates/xml/tests/de_struct.rs index e194579..fde54c7 100644 --- a/crates/xml/tests/de_struct.rs +++ b/crates/xml/tests/de_struct.rs @@ -2,10 +2,11 @@ use rustical_xml::de::XmlRootParseStr; use rustical_xml::{Unit, Unparsed, XmlDeserialize}; use std::collections::HashSet; use std::io::BufRead; +use xml_derive::XmlRoot; #[test] fn test_struct_text_field() { - #[derive(Debug, XmlDeserialize, PartialEq)] + #[derive(Debug, XmlDeserialize, XmlRoot, PartialEq)] #[xml(root = b"document")] struct Document { #[xml(ty = "text")] @@ -26,7 +27,7 @@ fn test_struct_text_field() { #[test] fn test_struct_document() { - #[derive(Debug, XmlDeserialize, PartialEq)] + #[derive(Debug, XmlDeserialize, XmlRoot, PartialEq)] #[xml(root = b"document")] struct Document { child: Child, @@ -51,7 +52,7 @@ fn test_struct_document() { #[test] fn test_struct_rename_field() { - #[derive(Debug, XmlDeserialize, PartialEq)] + #[derive(Debug, XmlDeserialize, XmlRoot, PartialEq)] #[xml(root = b"document")] struct Document { #[xml(rename = b"ok-wow")] @@ -77,7 +78,7 @@ fn test_struct_rename_field() { #[test] fn test_struct_optional_field() { - #[derive(Debug, XmlDeserialize, PartialEq)] + #[derive(Debug, XmlDeserialize, XmlRoot, PartialEq)] #[xml(root = b"document")] struct Document { #[xml(default = "Default::default")] @@ -96,7 +97,7 @@ fn test_struct_optional_field() { #[test] fn test_struct_vec() { - #[derive(Debug, XmlDeserialize, PartialEq)] + #[derive(Debug, XmlDeserialize, XmlRoot, PartialEq)] #[xml(root = b"document")] struct Document { #[xml(rename = b"child", flatten)] @@ -124,7 +125,7 @@ fn test_struct_vec() { #[test] fn test_struct_set() { - #[derive(Debug, XmlDeserialize, PartialEq)] + #[derive(Debug, XmlDeserialize, XmlRoot, PartialEq)] #[xml(root = b"document")] struct Document { #[xml(rename = b"child", flatten)] @@ -152,7 +153,7 @@ fn test_struct_set() { #[test] fn test_struct_ns() { - #[derive(Debug, XmlDeserialize, PartialEq)] + #[derive(Debug, XmlDeserialize, XmlRoot, PartialEq)] #[xml(root = b"document", ns_strict)] struct Document { #[xml(ns = b"hello")] @@ -165,7 +166,7 @@ fn test_struct_ns() { #[test] fn test_struct_attr() { - #[derive(Debug, XmlDeserialize, PartialEq)] + #[derive(Debug, XmlDeserialize, XmlRoot, PartialEq)] #[xml(root = b"document", ns_strict)] struct Document { #[xml(ns = b"hello")] @@ -192,7 +193,7 @@ fn test_struct_attr() { #[test] fn test_struct_generics() { - #[derive(XmlDeserialize)] + #[derive(XmlDeserialize, XmlRoot)] #[xml(root = b"document", ns_strict)] struct Document { child: T, @@ -212,7 +213,7 @@ fn test_struct_generics() { #[test] fn test_struct_unparsed() { - #[derive(XmlDeserialize)] + #[derive(XmlDeserialize, XmlRoot)] #[xml(root = b"document", ns_strict)] struct Document { child: Unparsed,