1 // SPDX-FileCopyrightText: 2020 Robin Krahl <[email protected]>
2 // SPDX-License-Identifier: Apache-2.0 or MIT
3
4 //! A derive macro for the [`merge::Merge`][] trait.
5 //!
6 //! See the documentation for the [`merge`][] crate for more information.
7 //!
8 //! [`merge`]: https://lib.rs/crates/merge
9 //! [`merge::Merge`]: https://docs.rs/merge/latest/merge/trait.Merge.html
10
11 extern crate proc_macro;
12
13 use proc_macro2::TokenStream;
14 use quote::{quote, quote_spanned};
15 use std::convert::TryFrom;
16 use syn::{Error, Result, Token};
17
18 struct Field {
19 name: syn::Member,
20 span: proc_macro2::Span,
21 attrs: FieldAttrs,
22 }
23
24 #[derive(Default)]
25 struct FieldAttrs {
26 skip: bool,
27 strategy: Option<syn::Path>,
28 }
29
30 enum FieldAttr {
31 Skip,
32 Strategy(syn::Path),
33 }
34
35 #[proc_macro_derive(Merge, attributes(merge))]
merge_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream36 pub fn merge_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
37 let ast = syn::parse(input).unwrap();
38 impl_merge(&ast)
39 .unwrap_or_else(Error::into_compile_error)
40 .into()
41 }
42
impl_merge(ast: &syn::DeriveInput) -> Result<TokenStream>43 fn impl_merge(ast: &syn::DeriveInput) -> Result<TokenStream> {
44 let name = &ast.ident;
45
46 if let syn::Data::Struct(syn::DataStruct { ref fields, .. }) = ast.data {
47 impl_merge_for_struct(name, fields)
48 } else {
49 Err(Error::new_spanned(
50 ast,
51 "merge::Merge can only be derived for structs",
52 ))
53 }
54 }
55
impl_merge_for_struct(name: &syn::Ident, fields: &syn::Fields) -> Result<TokenStream>56 fn impl_merge_for_struct(name: &syn::Ident, fields: &syn::Fields) -> Result<TokenStream> {
57 let assignments = gen_assignments(fields)?;
58
59 Ok(quote! {
60 impl ::merge::Merge for #name {
61 fn merge(&mut self, other: Self) {
62 #assignments
63 }
64 }
65 })
66 }
67
gen_assignments(fields: &syn::Fields) -> Result<TokenStream>68 fn gen_assignments(fields: &syn::Fields) -> Result<TokenStream> {
69 let fields = fields
70 .iter()
71 .enumerate()
72 .map(Field::try_from)
73 .collect::<Result<Vec<_>>>()?;
74 let assignments = fields
75 .iter()
76 .filter(|f| !f.attrs.skip)
77 .map(|f| gen_assignment(&f));
78 Ok(quote! {
79 #( #assignments )*
80 })
81 }
82
gen_assignment(field: &Field) -> TokenStream83 fn gen_assignment(field: &Field) -> TokenStream {
84 use syn::spanned::Spanned;
85
86 let name = &field.name;
87 if let Some(strategy) = &field.attrs.strategy {
88 quote_spanned!(strategy.span()=> #strategy(&mut self.#name, other.#name);)
89 } else {
90 quote_spanned!(field.span=> ::merge::Merge::merge(&mut self.#name, other.#name);)
91 }
92 }
93
94 impl TryFrom<(usize, &syn::Field)> for Field {
95 type Error = syn::Error;
96
try_from(data: (usize, &syn::Field)) -> std::result::Result<Self, Self::Error>97 fn try_from(data: (usize, &syn::Field)) -> std::result::Result<Self, Self::Error> {
98 use syn::spanned::Spanned;
99
100 let (index, field) = data;
101 Ok(Field {
102 name: if let Some(ident) = &field.ident {
103 syn::Member::Named(ident.clone())
104 } else {
105 syn::Member::Unnamed(index.into())
106 },
107 span: field.span(),
108 attrs: FieldAttrs::new(field.attrs.iter())?,
109 })
110 }
111 }
112
113 impl FieldAttrs {
new<'a, I: Iterator<Item = &'a syn::Attribute>>(iter: I) -> Result<Self>114 fn new<'a, I: Iterator<Item = &'a syn::Attribute>>(iter: I) -> Result<Self> {
115 let mut field_attrs = Self::default();
116
117 for attr in iter {
118 if !attr.path().is_ident("merge") {
119 continue;
120 }
121
122 let parser = syn::punctuated::Punctuated::<FieldAttr, Token![,]>::parse_terminated;
123 for attr in attr.parse_args_with(parser)? {
124 field_attrs.apply(attr);
125 }
126 }
127
128 Ok(field_attrs)
129 }
130
apply(&mut self, attr: FieldAttr)131 fn apply(&mut self, attr: FieldAttr) {
132 match attr {
133 FieldAttr::Skip => self.skip = true,
134 FieldAttr::Strategy(path) => self.strategy = Some(path),
135 }
136 }
137 }
138
139 impl syn::parse::Parse for FieldAttr {
parse(input: syn::parse::ParseStream) -> syn::parse::Result<Self>140 fn parse(input: syn::parse::ParseStream) -> syn::parse::Result<Self> {
141 let name: syn::Ident = input.parse()?;
142 if name == "skip" {
143 // TODO check remaining stream
144 Ok(FieldAttr::Skip)
145 } else if name == "strategy" {
146 let _: Token![=] = input.parse()?;
147 let path: syn::Path = input.parse()?;
148 Ok(FieldAttr::Strategy(path))
149 } else {
150 Err(Error::new_spanned(
151 &name,
152 format!("Unexpected attribute: {}", name),
153 ))
154 }
155 }
156 }
157