1 extern crate proc_macro;
2 
3 use proc_macro2::{Span, TokenStream};
4 use quote::quote;
5 use syn::*;
6 
7 mod container_attributes;
8 mod field_attributes;
9 use container_attributes::ContainerAttributes;
10 use field_attributes::{determine_field_constructor, FieldConstructor};
11 
/// Name of the derive's helper attribute (`#[arbitrary(...)]`); used in this
/// file's error messages.
const ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";
/// The lifetime introduced on every generated impl
/// (`impl<'arbitrary, ...> arbitrary::Arbitrary<'arbitrary> for ...`).
const ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";
14 
15 #[proc_macro_derive(Arbitrary, attributes(arbitrary))]
derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream16 pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
17     let input = syn::parse_macro_input!(tokens as syn::DeriveInput);
18     expand_derive_arbitrary(input)
19         .unwrap_or_else(syn::Error::into_compile_error)
20         .into()
21 }
22 
expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream>23 fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
24     let container_attrs = ContainerAttributes::from_derive_input(&input)?;
25 
26     let (lifetime_without_bounds, lifetime_with_bounds) =
27         build_arbitrary_lifetime(input.generics.clone());
28 
29     let recursive_count = syn::Ident::new(
30         &format!("RECURSIVE_COUNT_{}", input.ident),
31         Span::call_site(),
32     );
33 
34     let arbitrary_method =
35         gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?;
36     let size_hint_method = gen_size_hint_method(&input)?;
37     let name = input.ident;
38 
39     // Apply user-supplied bounds or automatic `T: ArbitraryBounds`.
40     let generics = apply_trait_bounds(
41         input.generics,
42         lifetime_without_bounds.clone(),
43         &container_attrs,
44     )?;
45 
46     // Build ImplGeneric with a lifetime (https://github.com/dtolnay/syn/issues/90)
47     let mut generics_with_lifetime = generics.clone();
48     generics_with_lifetime
49         .params
50         .push(GenericParam::Lifetime(lifetime_with_bounds));
51     let (impl_generics, _, _) = generics_with_lifetime.split_for_impl();
52 
53     // Build TypeGenerics and WhereClause without a lifetime
54     let (_, ty_generics, where_clause) = generics.split_for_impl();
55 
56     Ok(quote! {
57         const _: () = {
58             std::thread_local! {
59                 #[allow(non_upper_case_globals)]
60                 static #recursive_count: std::cell::Cell<u32> = std::cell::Cell::new(0);
61             }
62 
63             #[automatically_derived]
64             impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause {
65                 #arbitrary_method
66                 #size_hint_method
67             }
68         };
69     })
70 }
71 
72 // Returns: (lifetime without bounds, lifetime with bounds)
73 // Example: ("'arbitrary", "'arbitrary: 'a + 'b")
build_arbitrary_lifetime(generics: Generics) -> (LifetimeParam, LifetimeParam)74 fn build_arbitrary_lifetime(generics: Generics) -> (LifetimeParam, LifetimeParam) {
75     let lifetime_without_bounds =
76         LifetimeParam::new(Lifetime::new(ARBITRARY_LIFETIME_NAME, Span::call_site()));
77     let mut lifetime_with_bounds = lifetime_without_bounds.clone();
78 
79     for param in generics.params.iter() {
80         if let GenericParam::Lifetime(lifetime_def) = param {
81             lifetime_with_bounds
82                 .bounds
83                 .push(lifetime_def.lifetime.clone());
84         }
85     }
86 
87     (lifetime_without_bounds, lifetime_with_bounds)
88 }
89 
apply_trait_bounds( mut generics: Generics, lifetime: LifetimeParam, container_attrs: &ContainerAttributes, ) -> Result<Generics>90 fn apply_trait_bounds(
91     mut generics: Generics,
92     lifetime: LifetimeParam,
93     container_attrs: &ContainerAttributes,
94 ) -> Result<Generics> {
95     // If user-supplied bounds exist, apply them to their matching type parameters.
96     if let Some(config_bounds) = &container_attrs.bounds {
97         let mut config_bounds_applied = 0;
98         for param in generics.params.iter_mut() {
99             if let GenericParam::Type(type_param) = param {
100                 if let Some(replacement) = config_bounds
101                     .iter()
102                     .flatten()
103                     .find(|p| p.ident == type_param.ident)
104                 {
105                     *type_param = replacement.clone();
106                     config_bounds_applied += 1;
107                 } else {
108                     // If no user-supplied bounds exist for this type, delete the original bounds.
109                     // This mimics serde.
110                     type_param.bounds = Default::default();
111                     type_param.default = None;
112                 }
113             }
114         }
115         let config_bounds_supplied = config_bounds
116             .iter()
117             .map(|bounds| bounds.len())
118             .sum::<usize>();
119         if config_bounds_applied != config_bounds_supplied {
120             return Err(Error::new(
121                 Span::call_site(),
122                 format!(
123                     "invalid `{}` attribute. too many bounds, only {} out of {} are applicable",
124                     ARBITRARY_ATTRIBUTE_NAME, config_bounds_applied, config_bounds_supplied,
125                 ),
126             ));
127         }
128         Ok(generics)
129     } else {
130         // Otherwise, inject a `T: Arbitrary` bound for every parameter.
131         Ok(add_trait_bounds(generics, lifetime))
132     }
133 }
134 
135 // Add a bound `T: Arbitrary` to every type parameter T.
add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics136 fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics {
137     for param in generics.params.iter_mut() {
138         if let GenericParam::Type(type_param) = param {
139             type_param
140                 .bounds
141                 .push(parse_quote!(arbitrary::Arbitrary<#lifetime>));
142         }
143     }
144     generics
145 }
146 
with_recursive_count_guard( recursive_count: &syn::Ident, expr: impl quote::ToTokens, ) -> impl quote::ToTokens147 fn with_recursive_count_guard(
148     recursive_count: &syn::Ident,
149     expr: impl quote::ToTokens,
150 ) -> impl quote::ToTokens {
151     quote! {
152         let guard_against_recursion = u.is_empty();
153         if guard_against_recursion {
154             #recursive_count.with(|count| {
155                 if count.get() > 0 {
156                     return Err(arbitrary::Error::NotEnoughData);
157                 }
158                 count.set(count.get() + 1);
159                 Ok(())
160             })?;
161         }
162 
163         let result = (|| { #expr })();
164 
165         if guard_against_recursion {
166             #recursive_count.with(|count| {
167                 count.set(count.get() - 1);
168             });
169         }
170 
171         result
172     }
173 }
174 
gen_arbitrary_method( input: &DeriveInput, lifetime: LifetimeParam, recursive_count: &syn::Ident, ) -> Result<TokenStream>175 fn gen_arbitrary_method(
176     input: &DeriveInput,
177     lifetime: LifetimeParam,
178     recursive_count: &syn::Ident,
179 ) -> Result<TokenStream> {
180     fn arbitrary_structlike(
181         fields: &Fields,
182         ident: &syn::Ident,
183         lifetime: LifetimeParam,
184         recursive_count: &syn::Ident,
185     ) -> Result<TokenStream> {
186         let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field))?;
187         let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) });
188 
189         let arbitrary_take_rest = construct_take_rest(fields)?;
190         let take_rest_body =
191             with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary_take_rest) });
192 
193         Ok(quote! {
194             fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
195                 #body
196             }
197 
198             fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
199                 #take_rest_body
200             }
201         })
202     }
203 
204     let ident = &input.ident;
205     let output = match &input.data {
206         Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count)?,
207         Data::Union(data) => arbitrary_structlike(
208             &Fields::Named(data.fields.clone()),
209             ident,
210             lifetime,
211             recursive_count,
212         )?,
213         Data::Enum(data) => {
214             let variants: Vec<TokenStream> = data
215                 .variants
216                 .iter()
217                 .enumerate()
218                 .map(|(i, variant)| {
219                     let idx = i as u64;
220                     let variant_name = &variant.ident;
221                     construct(&variant.fields, |_, field| gen_constructor_for_field(field))
222                         .map(|ctor| quote! { #idx => #ident::#variant_name #ctor })
223                 })
224                 .collect::<Result<_>>()?;
225 
226             let variants_take_rest: Vec<TokenStream> = data
227                 .variants
228                 .iter()
229                 .enumerate()
230                 .map(|(i, variant)| {
231                     let idx = i as u64;
232                     let variant_name = &variant.ident;
233                     construct_take_rest(&variant.fields)
234                         .map(|ctor| quote! { #idx => #ident::#variant_name #ctor })
235                 })
236                 .collect::<Result<_>>()?;
237 
238             let count = data.variants.len() as u64;
239 
240             let arbitrary = with_recursive_count_guard(
241                 recursive_count,
242                 quote! {
243                     // Use a multiply + shift to generate a ranged random number
244                     // with slight bias. For details, see:
245                     // https://lemire.me/blog/2016/06/30/fast-random-shuffling
246                     Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(u)?) * #count) >> 32 {
247                         #(#variants,)*
248                         _ => unreachable!()
249                     })
250                 },
251             );
252 
253             let arbitrary_take_rest = with_recursive_count_guard(
254                 recursive_count,
255                 quote! {
256                     // Use a multiply + shift to generate a ranged random number
257                     // with slight bias. For details, see:
258                     // https://lemire.me/blog/2016/06/30/fast-random-shuffling
259                     Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(&mut u)?) * #count) >> 32 {
260                         #(#variants_take_rest,)*
261                         _ => unreachable!()
262                     })
263                 },
264             );
265 
266             quote! {
267                 fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
268                     #arbitrary
269                 }
270 
271                 fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
272                     #arbitrary_take_rest
273                 }
274             }
275         }
276     };
277     Ok(output)
278 }
279 
construct( fields: &Fields, ctor: impl Fn(usize, &Field) -> Result<TokenStream>, ) -> Result<TokenStream>280 fn construct(
281     fields: &Fields,
282     ctor: impl Fn(usize, &Field) -> Result<TokenStream>,
283 ) -> Result<TokenStream> {
284     let output = match fields {
285         Fields::Named(names) => {
286             let names: Vec<TokenStream> = names
287                 .named
288                 .iter()
289                 .enumerate()
290                 .map(|(i, f)| {
291                     let name = f.ident.as_ref().unwrap();
292                     ctor(i, f).map(|ctor| quote! { #name: #ctor })
293                 })
294                 .collect::<Result<_>>()?;
295             quote! { { #(#names,)* } }
296         }
297         Fields::Unnamed(names) => {
298             let names: Vec<TokenStream> = names
299                 .unnamed
300                 .iter()
301                 .enumerate()
302                 .map(|(i, f)| ctor(i, f).map(|ctor| quote! { #ctor }))
303                 .collect::<Result<_>>()?;
304             quote! { ( #(#names),* ) }
305         }
306         Fields::Unit => quote!(),
307     };
308     Ok(output)
309 }
310 
construct_take_rest(fields: &Fields) -> Result<TokenStream>311 fn construct_take_rest(fields: &Fields) -> Result<TokenStream> {
312     construct(fields, |idx, field| {
313         determine_field_constructor(field).map(|field_constructor| match field_constructor {
314             FieldConstructor::Default => quote!(Default::default()),
315             FieldConstructor::Arbitrary => {
316                 if idx + 1 == fields.len() {
317                     quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? }
318                 } else {
319                     quote! { arbitrary::Arbitrary::arbitrary(&mut u)? }
320                 }
321             }
322             FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(&mut u)?),
323             FieldConstructor::Value(value) => quote!(#value),
324         })
325     })
326 }
327 
gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream>328 fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
329     let size_hint_fields = |fields: &Fields| {
330         fields
331             .iter()
332             .map(|f| {
333                 let ty = &f.ty;
334                 determine_field_constructor(f).map(|field_constructor| {
335                     match field_constructor {
336                         FieldConstructor::Default | FieldConstructor::Value(_) => {
337                             quote!((0, Some(0)))
338                         }
339                         FieldConstructor::Arbitrary => {
340                             quote! { <#ty as arbitrary::Arbitrary>::size_hint(depth) }
341                         }
342 
343                         // Note that in this case it's hard to determine what size_hint must be, so size_of::<T>() is
344                         // just an educated guess, although it's gonna be inaccurate for dynamically
345                         // allocated types (Vec, HashMap, etc.).
346                         FieldConstructor::With(_) => {
347                             quote! { (::core::mem::size_of::<#ty>(), None) }
348                         }
349                     }
350                 })
351             })
352             .collect::<Result<Vec<TokenStream>>>()
353             .map(|hints| {
354                 quote! {
355                     arbitrary::size_hint::and_all(&[
356                         #( #hints ),*
357                     ])
358                 }
359             })
360     };
361     let size_hint_structlike = |fields: &Fields| {
362         size_hint_fields(fields).map(|hint| {
363             quote! {
364                 #[inline]
365                 fn size_hint(depth: usize) -> (usize, Option<usize>) {
366                     arbitrary::size_hint::recursion_guard(depth, |depth| #hint)
367                 }
368             }
369         })
370     };
371     match &input.data {
372         Data::Struct(data) => size_hint_structlike(&data.fields),
373         Data::Union(data) => size_hint_structlike(&Fields::Named(data.fields.clone())),
374         Data::Enum(data) => data
375             .variants
376             .iter()
377             .map(|v| size_hint_fields(&v.fields))
378             .collect::<Result<Vec<TokenStream>>>()
379             .map(|variants| {
380                 quote! {
381                     #[inline]
382                     fn size_hint(depth: usize) -> (usize, Option<usize>) {
383                         arbitrary::size_hint::and(
384                             <u32 as arbitrary::Arbitrary>::size_hint(depth),
385                             arbitrary::size_hint::recursion_guard(depth, |depth| {
386                                 arbitrary::size_hint::or_all(&[ #( #variants ),* ])
387                             }),
388                         )
389                     }
390                 }
391             }),
392     }
393 }
394 
gen_constructor_for_field(field: &Field) -> Result<TokenStream>395 fn gen_constructor_for_field(field: &Field) -> Result<TokenStream> {
396     let ctor = match determine_field_constructor(field)? {
397         FieldConstructor::Default => quote!(Default::default()),
398         FieldConstructor::Arbitrary => quote!(arbitrary::Arbitrary::arbitrary(u)?),
399         FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(u)?),
400         FieldConstructor::Value(value) => quote!(#value),
401     };
402     Ok(ctor)
403 }
404