xml: Move XmlRoot implementation into dedicated derive macro

This commit is contained in:
Lennart
2024-12-22 12:44:19 +01:00
parent 9fe5c00687
commit 043ce8bcd0
5 changed files with 37 additions and 25 deletions

View File

@@ -24,7 +24,8 @@ pub struct VariantAttrs {
#[darling(attributes(xml))] #[darling(attributes(xml))]
pub struct EnumAttrs { pub struct EnumAttrs {
#[darling(flatten)] #[darling(flatten)]
container: ContainerAttrs, pub container: ContainerAttrs,
pub untagged: Flag,
} }
#[derive(Default, FromDeriveInput, Clone)] #[derive(Default, FromDeriveInput, Clone)]

View File

@@ -56,6 +56,17 @@ pub struct NamedStruct {
} }
impl NamedStruct { impl NamedStruct {
pub fn impl_xml_root(&self) -> proc_macro2::TokenStream {
let (impl_generics, type_generics, where_clause) = self.generics.split_for_impl();
let ident = &self.ident;
let root = self.attrs.root.as_ref().expect("No root attribute found");
quote! {
impl #impl_generics ::rustical_xml::XmlRoot for #ident #type_generics #where_clause {
fn root_tag() -> &'static [u8] { #root }
}
}
}
pub fn impl_de(&self) -> proc_macro2::TokenStream { pub fn impl_de(&self) -> proc_macro2::TokenStream {
let (impl_generics, type_generics, where_clause) = self.generics.split_for_impl(); let (impl_generics, type_generics, where_clause) = self.generics.split_for_impl();
let ident = &self.ident; let ident = &self.ident;
@@ -76,21 +87,9 @@ impl NamedStruct {
let builder_field_builds = self.fields.iter().map(Field::builder_field_build); let builder_field_builds = self.fields.iter().map(Field::builder_field_build);
let xml_root_impl = if let Some(root) = &self.attrs.root {
quote! {
impl #impl_generics ::rustical_xml::XmlRoot for #ident #type_generics #where_clause {
fn root_tag() -> &'static [u8] { #root }
}
}
} else {
quote! {}
};
let invalid_field_branch = invalid_field_branch(self.attrs.allow_invalid.is_present()); let invalid_field_branch = invalid_field_branch(self.attrs.allow_invalid.is_present());
quote! { quote! {
#xml_root_impl
impl #impl_generics ::rustical_xml::XmlDeserialize for #ident #type_generics #where_clause { impl #impl_generics ::rustical_xml::XmlDeserialize for #ident #type_generics #where_clause {
fn deserialize<R: BufRead>( fn deserialize<R: BufRead>(
reader: &mut quick_xml::NsReader<R>, reader: &mut quick_xml::NsReader<R>,

View File

@@ -2,7 +2,6 @@ use core::panic;
use syn::{parse_macro_input, DeriveInput}; use syn::{parse_macro_input, DeriveInput};
mod de; mod de;
use de::{impl_de_enum, NamedStruct}; use de::{impl_de_enum, NamedStruct};
#[proc_macro_derive(XmlDeserialize, attributes(xml))] #[proc_macro_derive(XmlDeserialize, attributes(xml))]
@@ -16,3 +15,15 @@ pub fn derive_xml_deserialize(input: proc_macro::TokenStream) -> proc_macro::Tok
} }
.into() .into()
} }
#[proc_macro_derive(XmlRoot, attributes(xml))]
pub fn derive_xml_root(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(input as DeriveInput);
match &input.data {
syn::Data::Struct(s) => NamedStruct::parse(&input, s).impl_xml_root(),
syn::Data::Enum(_) => panic!("Enum not supported as root"),
syn::Data::Union(_) => panic!("Union not supported as root"),
}
.into()
}

View File

@@ -86,4 +86,4 @@ pub trait XmlRootParseStr<'i>: XmlRoot + XmlDeserialize {
} }
} }
impl<'i, T: XmlRoot + XmlDeserialize> XmlRootParseStr<'i> for T {} impl<T: XmlRoot + XmlDeserialize> XmlRootParseStr<'_> for T {}

View File

@@ -2,10 +2,11 @@ use rustical_xml::de::XmlRootParseStr;
use rustical_xml::{Unit, Unparsed, XmlDeserialize}; use rustical_xml::{Unit, Unparsed, XmlDeserialize};
use std::collections::HashSet; use std::collections::HashSet;
use std::io::BufRead; use std::io::BufRead;
use xml_derive::XmlRoot;
#[test] #[test]
fn test_struct_text_field() { fn test_struct_text_field() {
#[derive(Debug, XmlDeserialize, PartialEq)] #[derive(Debug, XmlDeserialize, XmlRoot, PartialEq)]
#[xml(root = b"document")] #[xml(root = b"document")]
struct Document { struct Document {
#[xml(ty = "text")] #[xml(ty = "text")]
@@ -26,7 +27,7 @@ fn test_struct_text_field() {
#[test] #[test]
fn test_struct_document() { fn test_struct_document() {
#[derive(Debug, XmlDeserialize, PartialEq)] #[derive(Debug, XmlDeserialize, XmlRoot, PartialEq)]
#[xml(root = b"document")] #[xml(root = b"document")]
struct Document { struct Document {
child: Child, child: Child,
@@ -51,7 +52,7 @@ fn test_struct_document() {
#[test] #[test]
fn test_struct_rename_field() { fn test_struct_rename_field() {
#[derive(Debug, XmlDeserialize, PartialEq)] #[derive(Debug, XmlDeserialize, XmlRoot, PartialEq)]
#[xml(root = b"document")] #[xml(root = b"document")]
struct Document { struct Document {
#[xml(rename = b"ok-wow")] #[xml(rename = b"ok-wow")]
@@ -77,7 +78,7 @@ fn test_struct_rename_field() {
#[test] #[test]
fn test_struct_optional_field() { fn test_struct_optional_field() {
#[derive(Debug, XmlDeserialize, PartialEq)] #[derive(Debug, XmlDeserialize, XmlRoot, PartialEq)]
#[xml(root = b"document")] #[xml(root = b"document")]
struct Document { struct Document {
#[xml(default = "Default::default")] #[xml(default = "Default::default")]
@@ -96,7 +97,7 @@ fn test_struct_optional_field() {
#[test] #[test]
fn test_struct_vec() { fn test_struct_vec() {
#[derive(Debug, XmlDeserialize, PartialEq)] #[derive(Debug, XmlDeserialize, XmlRoot, PartialEq)]
#[xml(root = b"document")] #[xml(root = b"document")]
struct Document { struct Document {
#[xml(rename = b"child", flatten)] #[xml(rename = b"child", flatten)]
@@ -124,7 +125,7 @@ fn test_struct_vec() {
#[test] #[test]
fn test_struct_set() { fn test_struct_set() {
#[derive(Debug, XmlDeserialize, PartialEq)] #[derive(Debug, XmlDeserialize, XmlRoot, PartialEq)]
#[xml(root = b"document")] #[xml(root = b"document")]
struct Document { struct Document {
#[xml(rename = b"child", flatten)] #[xml(rename = b"child", flatten)]
@@ -152,7 +153,7 @@ fn test_struct_set() {
#[test] #[test]
fn test_struct_ns() { fn test_struct_ns() {
#[derive(Debug, XmlDeserialize, PartialEq)] #[derive(Debug, XmlDeserialize, XmlRoot, PartialEq)]
#[xml(root = b"document", ns_strict)] #[xml(root = b"document", ns_strict)]
struct Document { struct Document {
#[xml(ns = b"hello")] #[xml(ns = b"hello")]
@@ -165,7 +166,7 @@ fn test_struct_ns() {
#[test] #[test]
fn test_struct_attr() { fn test_struct_attr() {
#[derive(Debug, XmlDeserialize, PartialEq)] #[derive(Debug, XmlDeserialize, XmlRoot, PartialEq)]
#[xml(root = b"document", ns_strict)] #[xml(root = b"document", ns_strict)]
struct Document { struct Document {
#[xml(ns = b"hello")] #[xml(ns = b"hello")]
@@ -192,7 +193,7 @@ fn test_struct_attr() {
#[test] #[test]
fn test_struct_generics() { fn test_struct_generics() {
#[derive(XmlDeserialize)] #[derive(XmlDeserialize, XmlRoot)]
#[xml(root = b"document", ns_strict)] #[xml(root = b"document", ns_strict)]
struct Document<T: XmlDeserialize> { struct Document<T: XmlDeserialize> {
child: T, child: T,
@@ -212,7 +213,7 @@ fn test_struct_generics() {
#[test] #[test]
fn test_struct_unparsed() { fn test_struct_unparsed() {
#[derive(XmlDeserialize)] #[derive(XmlDeserialize, XmlRoot)]
#[xml(root = b"document", ns_strict)] #[xml(root = b"document", ns_strict)]
struct Document { struct Document {
child: Unparsed, child: Unparsed,