1 // SPDX-FileCopyrightText: 2020 Robin Krahl <[email protected]>
2 // SPDX-License-Identifier: Apache-2.0 or MIT
3 
4 //! A derive macro for the [`merge::Merge`][] trait.
5 //!
6 //! See the documentation for the [`merge`][] crate for more information.
7 //!
8 //! [`merge`]: https://lib.rs/crates/merge
9 //! [`merge::Merge`]: https://docs.rs/merge/latest/merge/trait.Merge.html
10 
11 extern crate proc_macro;
12 
13 use proc_macro2::TokenStream;
14 use quote::{quote, quote_spanned};
15 use std::convert::TryFrom;
16 use syn::{Error, Result, Token};
17 
18 struct Field {
19     name: syn::Member,
20     span: proc_macro2::Span,
21     attrs: FieldAttrs,
22 }
23 
24 #[derive(Default)]
25 struct FieldAttrs {
26     skip: bool,
27     strategy: Option<syn::Path>,
28 }
29 
30 enum FieldAttr {
31     Skip,
32     Strategy(syn::Path),
33 }
34 
35 #[proc_macro_derive(Merge, attributes(merge))]
merge_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream36 pub fn merge_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
37     let ast = syn::parse(input).unwrap();
38     impl_merge(&ast)
39         .unwrap_or_else(Error::into_compile_error)
40         .into()
41 }
42 
impl_merge(ast: &syn::DeriveInput) -> Result<TokenStream>43 fn impl_merge(ast: &syn::DeriveInput) -> Result<TokenStream> {
44     let name = &ast.ident;
45 
46     if let syn::Data::Struct(syn::DataStruct { ref fields, .. }) = ast.data {
47         impl_merge_for_struct(name, fields)
48     } else {
49         Err(Error::new_spanned(
50             ast,
51             "merge::Merge can only be derived for structs",
52         ))
53     }
54 }
55 
impl_merge_for_struct(name: &syn::Ident, fields: &syn::Fields) -> Result<TokenStream>56 fn impl_merge_for_struct(name: &syn::Ident, fields: &syn::Fields) -> Result<TokenStream> {
57     let assignments = gen_assignments(fields)?;
58 
59     Ok(quote! {
60         impl ::merge::Merge for #name {
61             fn merge(&mut self, other: Self) {
62                 #assignments
63             }
64         }
65     })
66 }
67 
gen_assignments(fields: &syn::Fields) -> Result<TokenStream>68 fn gen_assignments(fields: &syn::Fields) -> Result<TokenStream> {
69     let fields = fields
70         .iter()
71         .enumerate()
72         .map(Field::try_from)
73         .collect::<Result<Vec<_>>>()?;
74     let assignments = fields
75         .iter()
76         .filter(|f| !f.attrs.skip)
77         .map(|f| gen_assignment(&f));
78     Ok(quote! {
79         #( #assignments )*
80     })
81 }
82 
gen_assignment(field: &Field) -> TokenStream83 fn gen_assignment(field: &Field) -> TokenStream {
84     use syn::spanned::Spanned;
85 
86     let name = &field.name;
87     if let Some(strategy) = &field.attrs.strategy {
88         quote_spanned!(strategy.span()=> #strategy(&mut self.#name, other.#name);)
89     } else {
90         quote_spanned!(field.span=> ::merge::Merge::merge(&mut self.#name, other.#name);)
91     }
92 }
93 
94 impl TryFrom<(usize, &syn::Field)> for Field {
95     type Error = syn::Error;
96 
try_from(data: (usize, &syn::Field)) -> std::result::Result<Self, Self::Error>97     fn try_from(data: (usize, &syn::Field)) -> std::result::Result<Self, Self::Error> {
98         use syn::spanned::Spanned;
99 
100         let (index, field) = data;
101         Ok(Field {
102             name: if let Some(ident) = &field.ident {
103                 syn::Member::Named(ident.clone())
104             } else {
105                 syn::Member::Unnamed(index.into())
106             },
107             span: field.span(),
108             attrs: FieldAttrs::new(field.attrs.iter())?,
109         })
110     }
111 }
112 
113 impl FieldAttrs {
new<'a, I: Iterator<Item = &'a syn::Attribute>>(iter: I) -> Result<Self>114     fn new<'a, I: Iterator<Item = &'a syn::Attribute>>(iter: I) -> Result<Self> {
115         let mut field_attrs = Self::default();
116 
117         for attr in iter {
118             if !attr.path().is_ident("merge") {
119                 continue;
120             }
121 
122             let parser = syn::punctuated::Punctuated::<FieldAttr, Token![,]>::parse_terminated;
123             for attr in attr.parse_args_with(parser)? {
124                 field_attrs.apply(attr);
125             }
126         }
127 
128         Ok(field_attrs)
129     }
130 
apply(&mut self, attr: FieldAttr)131     fn apply(&mut self, attr: FieldAttr) {
132         match attr {
133             FieldAttr::Skip => self.skip = true,
134             FieldAttr::Strategy(path) => self.strategy = Some(path),
135         }
136     }
137 }
138 
139 impl syn::parse::Parse for FieldAttr {
parse(input: syn::parse::ParseStream) -> syn::parse::Result<Self>140     fn parse(input: syn::parse::ParseStream) -> syn::parse::Result<Self> {
141         let name: syn::Ident = input.parse()?;
142         if name == "skip" {
143             // TODO check remaining stream
144             Ok(FieldAttr::Skip)
145         } else if name == "strategy" {
146             let _: Token![=] = input.parse()?;
147             let path: syn::Path = input.parse()?;
148             Ok(FieldAttr::Strategy(path))
149         } else {
150             Err(Error::new_spanned(
151                 &name,
152                 format!("Unexpected attribute: {}", name),
153             ))
154         }
155     }
156 }
157