1 use crate::enum_attributes::ErrorTypeAttribute;
2 use crate::utils::die;
3 use crate::variant_attributes::{NumEnumVariantAttributeItem, NumEnumVariantAttributes};
4 use proc_macro2::Span;
5 use quote::{format_ident, ToTokens};
6 use std::collections::BTreeSet;
7 use syn::{
8     parse::{Parse, ParseStream},
9     parse_quote, Attribute, Data, DeriveInput, Expr, ExprLit, ExprUnary, Fields, Ident, Lit,
10     LitInt, Meta, Path, Result, UnOp,
11 };
12 
13 pub(crate) struct EnumInfo {
14     pub(crate) name: Ident,
15     pub(crate) repr: Ident,
16     pub(crate) variants: Vec<VariantInfo>,
17     pub(crate) error_type_info: ErrorType,
18 }
19 
20 impl EnumInfo {
21     /// Returns whether the number of variants (ignoring defaults, catch-alls, etc) is the same as
22     /// the capacity of the repr.
is_naturally_exhaustive(&self) -> Result<bool>23     pub(crate) fn is_naturally_exhaustive(&self) -> Result<bool> {
24         let repr_str = self.repr.to_string();
25         if !repr_str.is_empty() {
26             let suffix = repr_str
27                 .strip_prefix('i')
28                 .or_else(|| repr_str.strip_prefix('u'));
29             if let Some(suffix) = suffix {
30                 if suffix == "size" {
31                     return Ok(false);
32                 } else if let Ok(bits) = suffix.parse::<u32>() {
33                     let variants = 1usize.checked_shl(bits);
34                     return Ok(variants.map_or(false, |v| {
35                         v == self
36                             .variants
37                             .iter()
38                             .map(|v| v.alternative_values.len() + 1)
39                             .sum()
40                     }));
41                 }
42             }
43         }
44         die!(self.repr.clone() => "Failed to parse repr into bit size");
45     }
46 
default(&self) -> Option<&Ident>47     pub(crate) fn default(&self) -> Option<&Ident> {
48         self.variants
49             .iter()
50             .find(|info| info.is_default)
51             .map(|info| &info.ident)
52     }
53 
catch_all(&self) -> Option<&Ident>54     pub(crate) fn catch_all(&self) -> Option<&Ident> {
55         self.variants
56             .iter()
57             .find(|info| info.is_catch_all)
58             .map(|info| &info.ident)
59     }
60 
variant_idents(&self) -> Vec<Ident>61     pub(crate) fn variant_idents(&self) -> Vec<Ident> {
62         self.variants
63             .iter()
64             .filter(|variant| !variant.is_catch_all)
65             .map(|variant| variant.ident.clone())
66             .collect()
67     }
68 
expression_idents(&self) -> Vec<Vec<Ident>>69     pub(crate) fn expression_idents(&self) -> Vec<Vec<Ident>> {
70         self.variants
71             .iter()
72             .filter(|variant| !variant.is_catch_all)
73             .map(|info| {
74                 let indices = 0..(info.alternative_values.len() + 1);
75                 indices
76                     .map(|index| format_ident!("{}__num_enum_{}__", info.ident, index))
77                     .collect()
78             })
79             .collect()
80     }
81 
variant_expressions(&self) -> Vec<Vec<Expr>>82     pub(crate) fn variant_expressions(&self) -> Vec<Vec<Expr>> {
83         self.variants
84             .iter()
85             .filter(|variant| !variant.is_catch_all)
86             .map(|variant| variant.all_values().cloned().collect())
87             .collect()
88     }
89 
parse_attrs<Attrs: Iterator<Item = Attribute>>( attrs: Attrs, ) -> Result<(Ident, Option<ErrorType>)>90     fn parse_attrs<Attrs: Iterator<Item = Attribute>>(
91         attrs: Attrs,
92     ) -> Result<(Ident, Option<ErrorType>)> {
93         let mut maybe_repr = None;
94         let mut maybe_error_type = None;
95         for attr in attrs {
96             if let Meta::List(meta_list) = &attr.meta {
97                 if let Some(ident) = meta_list.path.get_ident() {
98                     if ident == "repr" {
99                         let mut nested = meta_list.tokens.clone().into_iter();
100                         let repr_tree = match (nested.next(), nested.next()) {
101                             (Some(repr_tree), None) => repr_tree,
102                             _ => die!(attr =>
103                                 "Expected exactly one `repr` argument"
104                             ),
105                         };
106                         let repr_ident: Ident = parse_quote! {
107                             #repr_tree
108                         };
109                         if repr_ident == "C" {
110                             die!(repr_ident =>
111                                 "repr(C) doesn't have a well defined size"
112                             );
113                         } else {
114                             maybe_repr = Some(repr_ident);
115                         }
116                     } else if ident == "num_enum" {
117                         let attributes =
118                             attr.parse_args_with(crate::enum_attributes::Attributes::parse)?;
119                         if let Some(error_type) = attributes.error_type {
120                             if maybe_error_type.is_some() {
121                                 die!(attr => "At most one num_enum error_type attribute may be specified");
122                             }
123                             maybe_error_type = Some(error_type.into());
124                         }
125                     }
126                 }
127             }
128         }
129         if maybe_repr.is_none() {
130             die!("Missing `#[repr({Integer})]` attribute");
131         }
132         Ok((maybe_repr.unwrap(), maybe_error_type))
133     }
134 }
135 
impl Parse for EnumInfo {
    /// Parses derive input into an [`EnumInfo`].
    ///
    /// Walks the enum's variants in declaration order. For each variant it:
    /// - computes the discriminant (explicit, or the previous one plus one);
    /// - records `#[default]` / `#[num_enum(default)]`, `#[num_enum(catch_all)]` and `num_enum`
    ///   alternative values;
    /// - rejects collisions that can be detected between integer-literal discriminants and
    ///   alternative values.
    ///
    /// # Errors
    ///
    /// Returns a spanned error in any of these cases:
    /// - the input is not an enum;
    /// - `parse_attrs` fails;
    /// - more than one default or catch-all variant exists, or both kinds exist;
    /// - a non-unit variant is not a valid catch-all;
    /// - a `num_enum` attribute is invalid;
    /// - a non-literal alternative value is used;
    /// - duplicate or colliding literal values are found.
    fn parse(input: ParseStream) -> Result<Self> {
        // The whole body is a single block expression wrapped in `Ok`; every error path goes
        // through `?` or `die!` (from `crate::utils`), which is used here as an early return.
        Ok({
            let input: DeriveInput = input.parse()?;
            let name = input.ident;
            let data = match input.data {
                Data::Enum(data) => data,
                Data::Union(data) => die!(data.union_token => "Expected enum but found union"),
                Data::Struct(data) => die!(data.struct_token => "Expected enum but found struct"),
            };

            let (repr, maybe_error_type) = Self::parse_attrs(input.attrs.into_iter())?;

            let mut variants: Vec<VariantInfo> = vec![];
            // Enum-wide flags: at most one default and at most one catch-all, never both.
            let mut has_default_variant: bool = false;
            let mut has_catch_all_variant: bool = false;

            // Ordered set of every integer-literal discriminant and alternative value seen so far
            // (a BTreeSet, so its maximum is cheap to read).
            let mut discriminant_int_val_set = BTreeSet::new();

            // A variant without an explicit discriminant gets the previous one plus one; the
            // first variant starts at 0.
            let mut next_discriminant = literal(0);
            for variant in data.variants.into_iter() {
                let ident = variant.ident.clone();

                let discriminant = match &variant.discriminant {
                    Some(d) => d.1.clone(),
                    None => next_discriminant.clone(),
                };

                let mut raw_alternative_values: Vec<Expr> = vec![];
                // Keep the attribute around for better error reporting.
                let mut alt_attr_ref: Vec<&Attribute> = vec![];

                // `#[num_enum(default)]` is required by `#[derive(FromPrimitive)]`
                // and forbidden by `#[derive(UnsafeFromPrimitive)]`, so we need to
                // keep track of whether we encountered such an attribute:
                let mut is_default: bool = false;
                let mut is_catch_all: bool = false;

                for attribute in &variant.attrs {
                    // The standard `#[default]` attribute is accepted as a synonym for
                    // `#[num_enum(default)]`.
                    if attribute.path().is_ident("default") {
                        if has_default_variant {
                            die!(attribute =>
                                "Multiple variants marked `#[default]` or `#[num_enum(default)]` found"
                            );
                        } else if has_catch_all_variant {
                            die!(attribute =>
                                "Attribute `default` is mutually exclusive with `catch_all`"
                            );
                        }
                        is_default = true;
                        has_default_variant = true;
                    }

                    if attribute.path().is_ident("num_enum") {
                        match attribute.parse_args_with(NumEnumVariantAttributes::parse) {
                            Ok(variant_attributes) => {
                                for variant_attribute in variant_attributes.items {
                                    match variant_attribute {
                                        NumEnumVariantAttributeItem::Default(default) => {
                                            if has_default_variant {
                                                die!(default.keyword =>
                                                    "Multiple variants marked `#[default]` or `#[num_enum(default)]` found"
                                                );
                                            } else if has_catch_all_variant {
                                                die!(default.keyword =>
                                                    "Attribute `default` is mutually exclusive with `catch_all`"
                                                );
                                            }
                                            is_default = true;
                                            has_default_variant = true;
                                        }
                                        NumEnumVariantAttributeItem::CatchAll(catch_all) => {
                                            if has_catch_all_variant {
                                                die!(catch_all.keyword =>
                                                    "Multiple variants marked with `#[num_enum(catch_all)]`"
                                                );
                                            } else if has_default_variant {
                                                die!(catch_all.keyword =>
                                                    "Attribute `catch_all` is mutually exclusive with `default`"
                                                );
                                            }

                                            // A catch-all must carry exactly one field whose type
                                            // is written as the repr type (e.g. `Other(u8)`).
                                            match variant
                                                .fields
                                                .iter()
                                                .collect::<Vec<_>>()
                                                .as_slice()
                                            {
                                                [syn::Field {
                                                    ty: syn::Type::Path(syn::TypePath { path, .. }),
                                                    ..
                                                }] if path.is_ident(&repr) => {
                                                    is_catch_all = true;
                                                    has_catch_all_variant = true;
                                                }
                                                _ => {
                                                    die!(catch_all.keyword =>
                                                        "Variant with `catch_all` must be a tuple with exactly 1 field matching the repr type"
                                                    );
                                                }
                                            }
                                        }
                                        NumEnumVariantAttributeItem::Alternatives(alternatives) => {
                                            raw_alternative_values.extend(alternatives.expressions);
                                            alt_attr_ref.push(attribute);
                                        }
                                    }
                                }
                            }
                            Err(err) => {
                                if cfg!(not(feature = "complex-expressions")) {
                                    // Heuristic: a failed parse whose text mentions both
                                    // `alternatives` and `..` is most likely a range, which
                                    // needs the `complex-expressions` feature.
                                    let tokens = attribute.meta.to_token_stream();

                                    let attribute_str = format!("{}", tokens);
                                    if attribute_str.contains("alternatives")
                                        && attribute_str.contains("..")
                                    {
                                        // Give a nice error message suggesting how to fix the problem.
                                        die!(attribute => "Ranges are only supported as num_enum alternate values if the `complex-expressions` feature of the crate `num_enum` is enabled".to_string())
                                    }
                                }
                                die!(attribute =>
                                    format!("Invalid attribute: {}", err)
                                );
                            }
                        }
                    }
                }

                // Only the catch-all variant may carry data; every other variant must be a unit.
                if !is_catch_all {
                    match &variant.fields {
                        Fields::Named(_) | Fields::Unnamed(_) => {
                            die!(variant => format!("`{}` only supports unit variants (with no associated data), but `{}::{}` was not a unit variant.", get_crate_name(), name, ident));
                        }
                        Fields::Unit => {}
                    }
                }

                let discriminant_value = parse_discriminant(&discriminant)?;

                // Check for collision.
                // We can't do const evaluation, or even compare arbitrary Exprs,
                // so unfortunately we can't check non-literal discriminants for duplicates.
                // That's not the end of the world, just we'll end up with compile errors for
                // matches with duplicate branches in generated code instead of nice friendly error messages.
                if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
                    if discriminant_int_val_set.contains(&canonical_value_int) {
                        die!(ident => format!("The discriminant '{}' collides with a value attributed to a previous variant", canonical_value_int))
                    }
                }

                // Deal with the alternative values.
                // A single raw expression may expand to several values (ranges, with the
                // `complex-expressions` feature). The raw expression is pushed once per expanded
                // value, so both vecs stay index-aligned for error spans.
                let mut flattened_alternative_values = Vec::new();
                let mut flattened_raw_alternative_values = Vec::new();
                for raw_alternative_value in raw_alternative_values {
                    let expanded_values = parse_alternative_values(&raw_alternative_value)?;
                    for expanded_value in expanded_values {
                        flattened_alternative_values.push(expanded_value);
                        flattened_raw_alternative_values.push(raw_alternative_value.clone())
                    }
                }

                if !flattened_alternative_values.is_empty() {
                    let alternate_int_values = flattened_alternative_values
                        .into_iter()
                        .map(|v| {
                            match v {
                                DiscriminantValue::Literal(value) => Ok(value),
                                DiscriminantValue::Expr(expr) => {
                                    if let Expr::Range(_) = expr {
                                        if cfg!(not(feature = "complex-expressions")) {
                                            // Give a nice error message suggesting how to fix the problem.
                                            die!(expr => "Ranges are only supported as num_enum alternate values if the `complex-expressions` feature of the crate `num_enum` is enabled".to_string())
                                        }
                                    }
                                    // We can't do uniqueness checking on non-literals, so we don't allow them as alternate values.
                                    // We could probably allow them, but there doesn't seem to be much of a use-case,
                                    // and it's easier to give good error messages about duplicate values this way,
                                    // rather than rustc errors on conflicting match branches.
                                    die!(expr => "Only literals are allowed as num_enum alternate values".to_string())
                                },
                            }
                        })
                        .collect::<Result<Vec<i128>>>()?;
                    let mut sorted_alternate_int_values = alternate_int_values.clone();
                    sorted_alternate_int_values.sort_unstable();
                    let sorted_alternate_int_values = sorted_alternate_int_values;

                    // Check if the current discriminant is not in the alternative values.
                    if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
                        if let Some(index) = alternate_int_values
                            .iter()
                            .position(|&x| x == canonical_value_int)
                        {
                            die!(&flattened_raw_alternative_values[index] => format!("'{}' in the alternative values is already attributed as the discriminant of this variant", canonical_value_int));
                        }
                    }

                    // Search for duplicates; the vec is sorted, so duplicates are adjacent. Reject them.
                    if (1..sorted_alternate_int_values.len()).any(|i| {
                        sorted_alternate_int_values[i] == sorted_alternate_int_values[i - 1]
                    }) {
                        // Non-empty alternatives imply at least one alternatives attribute was recorded.
                        let attr = *alt_attr_ref.last().unwrap();
                        die!(attr => "There is duplication in the alternative values");
                    }
                    // Reject alternative values already attributed to a previous variant.
                    // (discriminant_int_val_set is a BTreeSet, and iter().next_back() is the maximum in the set;
                    // if our smallest value exceeds it, no collision is possible and the scan is skipped.)
                    if let Some(last_upper_val) = discriminant_int_val_set.iter().next_back() {
                        if sorted_alternate_int_values.first().unwrap() <= last_upper_val {
                            for (index, val) in alternate_int_values.iter().enumerate() {
                                if discriminant_int_val_set.contains(val) {
                                    die!(&flattened_raw_alternative_values[index] => format!("'{}' in the alternative values is already attributed to a previous variant", val));
                                }
                            }
                        }
                    }

                    // Reconstruct the alternative_values vec of Expr but sorted.
                    flattened_raw_alternative_values = sorted_alternate_int_values
                        .iter()
                        .map(|val| literal(val.to_owned()))
                        .collect();

                    // Add the alternative values to the set to keep track.
                    discriminant_int_val_set.extend(sorted_alternate_int_values);
                }

                // Add the current discriminant to the set to keep track.
                if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
                    discriminant_int_val_set.insert(canonical_value_int);
                }

                variants.push(VariantInfo {
                    ident,
                    is_default,
                    is_catch_all,
                    canonical_value: discriminant,
                    alternative_values: flattened_raw_alternative_values,
                });

                // Get the next value for the discriminant.
                // Literals are incremented here; non-literal expressions are wrapped in a
                // generated `<repr>::wrapping_add(expr, 1)` call for the compiler to evaluate.
                next_discriminant = match discriminant_value {
                    DiscriminantValue::Literal(int_value) => literal(int_value.wrapping_add(1)),
                    DiscriminantValue::Expr(expr) => {
                        parse_quote! {
                            #repr::wrapping_add(#expr, 1)
                        }
                    }
                }
            }

            // Without a `num_enum` error type attribute, fall back to the crate's
            // `TryFromPrimitiveError<Self>`, addressed through the (possibly renamed) crate name.
            let error_type_info = maybe_error_type.unwrap_or_else(|| {
                let crate_name = Ident::new(&get_crate_name(), Span::call_site());
                ErrorType {
                    name: parse_quote! {
                        ::#crate_name::TryFromPrimitiveError<Self>
                    },
                    constructor: parse_quote! {
                        ::#crate_name::TryFromPrimitiveError::<Self>::new
                    },
                }
            });

            EnumInfo {
                name,
                repr,
                variants,
                error_type_info,
            }
        })
    }
}
409 
literal(i: i128) -> Expr410 fn literal(i: i128) -> Expr {
411     Expr::Lit(ExprLit {
412         lit: Lit::Int(LitInt::new(&i.to_string(), Span::call_site())),
413         attrs: vec![],
414     })
415 }
416 
417 enum DiscriminantValue {
418     Literal(i128),
419     Expr(Expr),
420 }
421 
parse_discriminant(val_exp: &Expr) -> Result<DiscriminantValue>422 fn parse_discriminant(val_exp: &Expr) -> Result<DiscriminantValue> {
423     let mut sign = 1;
424     let mut unsigned_expr = val_exp;
425     if let Expr::Unary(ExprUnary {
426         op: UnOp::Neg(..),
427         expr,
428         ..
429     }) = val_exp
430     {
431         unsigned_expr = expr;
432         sign = -1;
433     }
434     if let Expr::Lit(ExprLit {
435         lit: Lit::Int(ref lit_int),
436         ..
437     }) = unsigned_expr
438     {
439         Ok(DiscriminantValue::Literal(
440             sign * lit_int.base10_parse::<i128>()?,
441         ))
442     } else {
443         Ok(DiscriminantValue::Expr(val_exp.clone()))
444     }
445 }
446 
/// Expands one alternative-value expression into the values it denotes.
///
/// A range (`a..b` or `a..=b`) whose bounds are both integer literals expands to one
/// [`DiscriminantValue::Literal`] per value in the range. Any other expression is handled by
/// [`parse_discriminant`] and yields a single value.
///
/// # Errors
///
/// Returns an error if a range bound is missing or not an integer literal, if the lower bound
/// exceeds the upper bound, or if a literal does not fit in an `i128`.
#[cfg(feature = "complex-expressions")]
fn parse_alternative_values(val_expr: &Expr) -> Result<Vec<DiscriminantValue>> {
    /// Resolves one bound of a range to its integer value, requiring an explicit literal.
    fn range_expr_value_to_number(
        parent_range_expr: &Expr,
        range_bound_value: &Option<Box<Expr>>,
    ) -> Result<i128> {
        // Avoid needing to calculate what the lower and upper bound would be - these are type
        // dependent, and also may not be obvious in context (e.g. an omitted bound could
        // reasonably mean "from the last discriminant" or "from the lower bound of the type").
        if let Some(range_bound_value) = range_bound_value {
            // If non-literals are used, we can't expand to the mapped values, so can't write a
            // nice match statement or do exhaustiveness checking. Require literals instead.
            if let DiscriminantValue::Literal(value) =
                parse_discriminant(range_bound_value.as_ref())?
            {
                return Ok(value);
            }
        }
        die!(parent_range_expr => "When ranges are used for alternate values, both bounds must be explicitly specified numeric literals")
    }

    if let Expr::Range(syn::ExprRange {
        start, end, limits, ..
    }) = val_expr
    {
        let lower = range_expr_value_to_number(val_expr, start)?;
        let upper = range_expr_value_to_number(val_expr, end)?;
        // While this is technically allowed in Rust, and results in an empty range, it's almost
        // certainly a mistake in this context.
        if lower > upper {
            die!(val_expr => "When using ranges for alternate values, upper bound must not be less than lower bound");
        }
        // Use std ranges instead of a manual counter. A hand-rolled `next += 1` overflows for
        // `..=i128::MAX`, and `(upper - lower) as usize` can overflow or truncate.
        let values = match limits {
            syn::RangeLimits::HalfOpen(..) => {
                (lower..upper).map(DiscriminantValue::Literal).collect()
            }
            syn::RangeLimits::Closed(..) => {
                (lower..=upper).map(DiscriminantValue::Literal).collect()
            }
        };
        return Ok(values);
    }
    parse_discriminant(val_expr).map(|v| vec![v])
}
498 
499 #[cfg(not(feature = "complex-expressions"))]
parse_alternative_values(val_expr: &Expr) -> Result<Vec<DiscriminantValue>>500 fn parse_alternative_values(val_expr: &Expr) -> Result<Vec<DiscriminantValue>> {
501     parse_discriminant(val_expr).map(|v| vec![v])
502 }
503 
504 pub(crate) struct VariantInfo {
505     ident: Ident,
506     is_default: bool,
507     is_catch_all: bool,
508     canonical_value: Expr,
509     alternative_values: Vec<Expr>,
510 }
511 
512 impl VariantInfo {
all_values(&self) -> impl Iterator<Item = &Expr>513     fn all_values(&self) -> impl Iterator<Item = &Expr> {
514         ::core::iter::once(&self.canonical_value).chain(self.alternative_values.iter())
515     }
516 }
517 
518 pub(crate) struct ErrorType {
519     pub(crate) name: Path,
520     pub(crate) constructor: Path,
521 }
522 
523 impl From<ErrorTypeAttribute> for ErrorType {
from(attribute: ErrorTypeAttribute) -> Self524     fn from(attribute: ErrorTypeAttribute) -> Self {
525         Self {
526             name: attribute.name.path,
527             constructor: attribute.constructor.path,
528         }
529     }
530 }
531 
/// Returns the crate name under which generated code should reference `num_enum`.
///
/// The name comes from `proc_macro_crate::crate_name`. It falls back to `num_enum` when that
/// lookup reports `FoundCrate::Itself`, and also when the lookup fails (after a warning on
/// stderr).
#[cfg(feature = "proc-macro-crate")]
pub(crate) fn get_crate_name() -> String {
    match proc_macro_crate::crate_name("num_enum") {
        Ok(proc_macro_crate::FoundCrate::Name(name)) => name,
        Ok(proc_macro_crate::FoundCrate::Itself) => String::from("num_enum"),
        Err(err) => {
            eprintln!("Warning: {}\n    => defaulting to `num_enum`", err);
            String::from("num_enum")
        }
    }
}
544 
// proc-macro-crate is deliberately not used in no_std environments: depending on it pulls in
// serde with std, which is an awkward dependency there.
//
// The trade-off is that no_std users of num_enum cannot rename the num_enum crate when they
// depend on it. Sorry.
//
// See https://github.com/illicitonion/num_enum/issues/18
/// Returns the crate name under which generated code references `num_enum`; without
/// proc-macro-crate this is always `num_enum`.
#[cfg(not(feature = "proc-macro-crate"))]
pub(crate) fn get_crate_name() -> String {
    "num_enum".to_owned()
}
555