1 use crate::ARBITRARY_ATTRIBUTE_NAME;
2 use proc_macro2::{Span, TokenStream, TokenTree};
3 use quote::quote;
4 use syn::{spanned::Spanned, *};
5 
6 /// Determines how a value for a field should be constructed.
7 #[cfg_attr(test, derive(Debug))]
8 pub enum FieldConstructor {
9     /// Assume that Arbitrary is defined for the type of this field and use it (default)
10     Arbitrary,
11 
12     /// Places `Default::default()` as a field value.
13     Default,
14 
15     /// Use custom function or closure to generate a value for a field.
16     With(TokenStream),
17 
18     /// Set a field always to the given value.
19     Value(TokenStream),
20 }
21 
determine_field_constructor(field: &Field) -> Result<FieldConstructor>22 pub fn determine_field_constructor(field: &Field) -> Result<FieldConstructor> {
23     let opt_attr = fetch_attr_from_field(field)?;
24     let ctor = match opt_attr {
25         Some(attr) => parse_attribute(attr)?,
26         None => FieldConstructor::Arbitrary,
27     };
28     Ok(ctor)
29 }
30 
fetch_attr_from_field(field: &Field) -> Result<Option<&Attribute>>31 fn fetch_attr_from_field(field: &Field) -> Result<Option<&Attribute>> {
32     let found_attributes: Vec<_> = field
33         .attrs
34         .iter()
35         .filter(|a| {
36             let path = a.path();
37             let name = quote!(#path).to_string();
38             name == ARBITRARY_ATTRIBUTE_NAME
39         })
40         .collect();
41     if found_attributes.len() > 1 {
42         let name = field.ident.as_ref().unwrap();
43         let msg = format!(
44             "Multiple conflicting #[{ARBITRARY_ATTRIBUTE_NAME}] attributes found on field `{name}`"
45         );
46         return Err(syn::Error::new(field.span(), msg));
47     }
48     Ok(found_attributes.into_iter().next())
49 }
50 
parse_attribute(attr: &Attribute) -> Result<FieldConstructor>51 fn parse_attribute(attr: &Attribute) -> Result<FieldConstructor> {
52     if let Meta::List(ref meta_list) = attr.meta {
53         parse_attribute_internals(meta_list)
54     } else {
55         let msg = format!("#[{ARBITRARY_ATTRIBUTE_NAME}] must contain a group");
56         Err(syn::Error::new(attr.span(), msg))
57     }
58 }
59 
parse_attribute_internals(meta_list: &MetaList) -> Result<FieldConstructor>60 fn parse_attribute_internals(meta_list: &MetaList) -> Result<FieldConstructor> {
61     let mut tokens_iter = meta_list.tokens.clone().into_iter();
62     let token = tokens_iter.next().ok_or_else(|| {
63         let msg = format!("#[{ARBITRARY_ATTRIBUTE_NAME}] cannot be empty.");
64         syn::Error::new(meta_list.span(), msg)
65     })?;
66     match token.to_string().as_ref() {
67         "default" => Ok(FieldConstructor::Default),
68         "with" => {
69             let func_path = parse_assigned_value("with", tokens_iter, meta_list.span())?;
70             Ok(FieldConstructor::With(func_path))
71         }
72         "value" => {
73             let value = parse_assigned_value("value", tokens_iter, meta_list.span())?;
74             Ok(FieldConstructor::Value(value))
75         }
76         _ => {
77             let msg = format!("Unknown option for #[{ARBITRARY_ATTRIBUTE_NAME}]: `{token}`");
78             Err(syn::Error::new(token.span(), msg))
79         }
80     }
81 }
82 
83 // Input:
84 //     = 2 + 2
85 // Output:
86 //     2 + 2
parse_assigned_value( opt_name: &str, mut tokens_iter: impl Iterator<Item = TokenTree>, default_span: Span, ) -> Result<TokenStream>87 fn parse_assigned_value(
88     opt_name: &str,
89     mut tokens_iter: impl Iterator<Item = TokenTree>,
90     default_span: Span,
91 ) -> Result<TokenStream> {
92     let eq_sign = tokens_iter.next().ok_or_else(|| {
93         let msg = format!(
94             "Invalid syntax for #[{ARBITRARY_ATTRIBUTE_NAME}], `{opt_name}` is missing assignment."
95         );
96         syn::Error::new(default_span, msg)
97     })?;
98 
99     if eq_sign.to_string() == "=" {
100         Ok(tokens_iter.collect())
101     } else {
102         let msg = format!("Invalid syntax for #[{ARBITRARY_ATTRIBUTE_NAME}], expected `=` after `{opt_name}`, got: `{eq_sign}`");
103         Err(syn::Error::new(eq_sign.span(), msg))
104     }
105 }
106