1 use crate::enum_attributes::ErrorTypeAttribute;
2 use crate::utils::die;
3 use crate::variant_attributes::{NumEnumVariantAttributeItem, NumEnumVariantAttributes};
4 use proc_macro2::Span;
5 use quote::{format_ident, ToTokens};
6 use std::collections::BTreeSet;
7 use syn::{
8 parse::{Parse, ParseStream},
9 parse_quote, Attribute, Data, DeriveInput, Expr, ExprLit, ExprUnary, Fields, Ident, Lit,
10 LitInt, Meta, Path, Result, UnOp,
11 };
12
13 pub(crate) struct EnumInfo {
14 pub(crate) name: Ident,
15 pub(crate) repr: Ident,
16 pub(crate) variants: Vec<VariantInfo>,
17 pub(crate) error_type_info: ErrorType,
18 }
19
20 impl EnumInfo {
21 /// Returns whether the number of variants (ignoring defaults, catch-alls, etc) is the same as
22 /// the capacity of the repr.
is_naturally_exhaustive(&self) -> Result<bool>23 pub(crate) fn is_naturally_exhaustive(&self) -> Result<bool> {
24 let repr_str = self.repr.to_string();
25 if !repr_str.is_empty() {
26 let suffix = repr_str
27 .strip_prefix('i')
28 .or_else(|| repr_str.strip_prefix('u'));
29 if let Some(suffix) = suffix {
30 if suffix == "size" {
31 return Ok(false);
32 } else if let Ok(bits) = suffix.parse::<u32>() {
33 let variants = 1usize.checked_shl(bits);
34 return Ok(variants.map_or(false, |v| {
35 v == self
36 .variants
37 .iter()
38 .map(|v| v.alternative_values.len() + 1)
39 .sum()
40 }));
41 }
42 }
43 }
44 die!(self.repr.clone() => "Failed to parse repr into bit size");
45 }
46
default(&self) -> Option<&Ident>47 pub(crate) fn default(&self) -> Option<&Ident> {
48 self.variants
49 .iter()
50 .find(|info| info.is_default)
51 .map(|info| &info.ident)
52 }
53
catch_all(&self) -> Option<&Ident>54 pub(crate) fn catch_all(&self) -> Option<&Ident> {
55 self.variants
56 .iter()
57 .find(|info| info.is_catch_all)
58 .map(|info| &info.ident)
59 }
60
variant_idents(&self) -> Vec<Ident>61 pub(crate) fn variant_idents(&self) -> Vec<Ident> {
62 self.variants
63 .iter()
64 .filter(|variant| !variant.is_catch_all)
65 .map(|variant| variant.ident.clone())
66 .collect()
67 }
68
expression_idents(&self) -> Vec<Vec<Ident>>69 pub(crate) fn expression_idents(&self) -> Vec<Vec<Ident>> {
70 self.variants
71 .iter()
72 .filter(|variant| !variant.is_catch_all)
73 .map(|info| {
74 let indices = 0..(info.alternative_values.len() + 1);
75 indices
76 .map(|index| format_ident!("{}__num_enum_{}__", info.ident, index))
77 .collect()
78 })
79 .collect()
80 }
81
variant_expressions(&self) -> Vec<Vec<Expr>>82 pub(crate) fn variant_expressions(&self) -> Vec<Vec<Expr>> {
83 self.variants
84 .iter()
85 .filter(|variant| !variant.is_catch_all)
86 .map(|variant| variant.all_values().cloned().collect())
87 .collect()
88 }
89
parse_attrs<Attrs: Iterator<Item = Attribute>>( attrs: Attrs, ) -> Result<(Ident, Option<ErrorType>)>90 fn parse_attrs<Attrs: Iterator<Item = Attribute>>(
91 attrs: Attrs,
92 ) -> Result<(Ident, Option<ErrorType>)> {
93 let mut maybe_repr = None;
94 let mut maybe_error_type = None;
95 for attr in attrs {
96 if let Meta::List(meta_list) = &attr.meta {
97 if let Some(ident) = meta_list.path.get_ident() {
98 if ident == "repr" {
99 let mut nested = meta_list.tokens.clone().into_iter();
100 let repr_tree = match (nested.next(), nested.next()) {
101 (Some(repr_tree), None) => repr_tree,
102 _ => die!(attr =>
103 "Expected exactly one `repr` argument"
104 ),
105 };
106 let repr_ident: Ident = parse_quote! {
107 #repr_tree
108 };
109 if repr_ident == "C" {
110 die!(repr_ident =>
111 "repr(C) doesn't have a well defined size"
112 );
113 } else {
114 maybe_repr = Some(repr_ident);
115 }
116 } else if ident == "num_enum" {
117 let attributes =
118 attr.parse_args_with(crate::enum_attributes::Attributes::parse)?;
119 if let Some(error_type) = attributes.error_type {
120 if maybe_error_type.is_some() {
121 die!(attr => "At most one num_enum error_type attribute may be specified");
122 }
123 maybe_error_type = Some(error_type.into());
124 }
125 }
126 }
127 }
128 }
129 if maybe_repr.is_none() {
130 die!("Missing `#[repr({Integer})]` attribute");
131 }
132 Ok((maybe_repr.unwrap(), maybe_error_type))
133 }
134 }
135
impl Parse for EnumInfo {
    /// Parses the derive input into an [`EnumInfo`], enforcing num_enum's rules as it goes.
    ///
    /// For every variant this does three things:
    /// - determines the canonical discriminant (explicit, or one more than the previous
    ///   variant's);
    /// - records `#[default]` / `#[num_enum(default)]`, `#[num_enum(catch_all)]` and
    ///   `alternatives` markers;
    /// - checks that integer-literal discriminants and alternative values are not claimed
    ///   twice.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the input is a struct or union, or `parse_attrs` fails;
    /// - more than one variant is marked default (or catch-all), or both markers are used;
    /// - a catch-all variant does not have exactly one field of the repr type;
    /// - a num_enum attribute cannot be parsed;
    /// - a non-catch-all variant carries data;
    /// - an alternative value is not a literal, is duplicated, or collides with a value
    ///   already in use.
    fn parse(input: ParseStream) -> Result<Self> {
        Ok({
            let input: DeriveInput = input.parse()?;
            let name = input.ident;
            let data = match input.data {
                Data::Enum(data) => data,
                Data::Union(data) => die!(data.union_token => "Expected enum but found union"),
                Data::Struct(data) => die!(data.struct_token => "Expected enum but found struct"),
            };

            // Enum-level attributes: the repr type and an optional custom error type.
            let (repr, maybe_error_type) = Self::parse_attrs(input.attrs.into_iter())?;

            let mut variants: Vec<VariantInfo> = vec![];
            let mut has_default_variant: bool = false;
            let mut has_catch_all_variant: bool = false;

            // Ordered set of every integer-literal discriminant and alternative value claimed
            // by the variants processed so far; used to detect collisions. Being ordered, its
            // maximum is available via `iter().next_back()` (used further down).
            let mut discriminant_int_val_set = BTreeSet::new();

            // Implicit discriminant for the next variant without an explicit `= value`.
            let mut next_discriminant = literal(0);
            for variant in data.variants.into_iter() {
                let ident = variant.ident.clone();

                // An explicit `= expr` wins; otherwise use the value carried over from the
                // previous variant.
                let discriminant = match &variant.discriminant {
                    Some(d) => d.1.clone(),
                    None => next_discriminant.clone(),
                };

                let mut raw_alternative_values: Vec<Expr> = vec![];
                // Keep the attribute around for better error reporting.
                let mut alt_attr_ref: Vec<&Attribute> = vec![];

                // `#[num_enum(default)]` is required by `#[derive(FromPrimitive)]`
                // and forbidden by `#[derive(UnsafeFromPrimitive)]`, so we need to
                // keep track of whether we encountered such an attribute:
                let mut is_default: bool = false;
                let mut is_catch_all: bool = false;

                for attribute in &variant.attrs {
                    // The built-in `#[default]` attribute is treated like `#[num_enum(default)]`.
                    if attribute.path().is_ident("default") {
                        if has_default_variant {
                            die!(attribute =>
                                "Multiple variants marked `#[default]` or `#[num_enum(default)]` found"
                            );
                        } else if has_catch_all_variant {
                            die!(attribute =>
                                "Attribute `default` is mutually exclusive with `catch_all`"
                            );
                        }
                        is_default = true;
                        has_default_variant = true;
                    }

                    if attribute.path().is_ident("num_enum") {
                        match attribute.parse_args_with(NumEnumVariantAttributes::parse) {
                            Ok(variant_attributes) => {
                                for variant_attribute in variant_attributes.items {
                                    match variant_attribute {
                                        NumEnumVariantAttributeItem::Default(default) => {
                                            if has_default_variant {
                                                die!(default.keyword =>
                                                    "Multiple variants marked `#[default]` or `#[num_enum(default)]` found"
                                                );
                                            } else if has_catch_all_variant {
                                                die!(default.keyword =>
                                                    "Attribute `default` is mutually exclusive with `catch_all`"
                                                );
                                            }
                                            is_default = true;
                                            has_default_variant = true;
                                        }
                                        NumEnumVariantAttributeItem::CatchAll(catch_all) => {
                                            if has_catch_all_variant {
                                                die!(catch_all.keyword =>
                                                    "Multiple variants marked with `#[num_enum(catch_all)]`"
                                                );
                                            } else if has_default_variant {
                                                die!(catch_all.keyword =>
                                                    "Attribute `catch_all` is mutually exclusive with `default`"
                                                );
                                            }

                                            // The catch-all must hold exactly one field whose
                                            // type path is the repr ident (e.g. `Other(u8)`).
                                            // NOTE(review): a single *named* field also matches
                                            // this pattern, although the message says "tuple".
                                            match variant
                                                .fields
                                                .iter()
                                                .collect::<Vec<_>>()
                                                .as_slice()
                                            {
                                                [syn::Field {
                                                    ty: syn::Type::Path(syn::TypePath { path, .. }),
                                                    ..
                                                }] if path.is_ident(&repr) => {
                                                    is_catch_all = true;
                                                    has_catch_all_variant = true;
                                                }
                                                _ => {
                                                    die!(catch_all.keyword =>
                                                        "Variant with `catch_all` must be a tuple with exactly 1 field matching the repr type"
                                                    );
                                                }
                                            }
                                        }
                                        NumEnumVariantAttributeItem::Alternatives(alternatives) => {
                                            raw_alternative_values.extend(alternatives.expressions);
                                            alt_attr_ref.push(attribute);
                                        }
                                    }
                                }
                            }
                            Err(err) => {
                                // Without `complex-expressions`, a range in `alternatives` is a
                                // parse error; detect it textually to give a targeted hint.
                                if cfg!(not(feature = "complex-expressions")) {
                                    let tokens = attribute.meta.to_token_stream();

                                    let attribute_str = format!("{}", tokens);
                                    if attribute_str.contains("alternatives")
                                        && attribute_str.contains("..")
                                    {
                                        // Give a nice error message suggesting how to fix the problem.
                                        die!(attribute => "Ranges are only supported as num_enum alternate values if the `complex-expressions` feature of the crate `num_enum` is enabled".to_string())
                                    }
                                }
                                die!(attribute =>
                                    format!("Invalid attribute: {}", err)
                                );
                            }
                        }
                    }
                }

                // Only the catch-all variant may carry data.
                if !is_catch_all {
                    match &variant.fields {
                        Fields::Named(_) | Fields::Unnamed(_) => {
                            die!(variant => format!("`{}` only supports unit variants (with no associated data), but `{}::{}` was not a unit variant.", get_crate_name(), name, ident));
                        }
                        Fields::Unit => {}
                    }
                }

                let discriminant_value = parse_discriminant(&discriminant)?;

                // Check for collision.
                // We can't do const evaluation, or even compare arbitrary Exprs, so only
                // discriminants that are integer literals can be checked for duplicates.
                // That's not the end of the world, just we'll end up with compile errors for
                // matches with duplicate branches in generated code instead of nice friendly error messages.
                if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
                    if discriminant_int_val_set.contains(&canonical_value_int) {
                        die!(ident => format!("The discriminant '{}' collides with a value attributed to a previous variant", canonical_value_int))
                    }
                }

                // Deal with the alternative values.
                // A raw alternative may expand to several values (ranges, with the
                // `complex-expressions` feature). Record the raw expression once per expanded
                // value so that errors found by index can point at the source expression.
                let mut flattened_alternative_values = Vec::new();
                let mut flattened_raw_alternative_values = Vec::new();
                for raw_alternative_value in raw_alternative_values {
                    let expanded_values = parse_alternative_values(&raw_alternative_value)?;
                    for expanded_value in expanded_values {
                        flattened_alternative_values.push(expanded_value);
                        flattened_raw_alternative_values.push(raw_alternative_value.clone())
                    }
                }

                if !flattened_alternative_values.is_empty() {
                    // Alternatives must be integer literals; convert each to i128 (order kept,
                    // so indices still line up with `flattened_raw_alternative_values`).
                    let alternate_int_values = flattened_alternative_values
                        .into_iter()
                        .map(|v| {
                            match v {
                                DiscriminantValue::Literal(value) => Ok(value),
                                DiscriminantValue::Expr(expr) => {
                                    if let Expr::Range(_) = expr {
                                        if cfg!(not(feature = "complex-expressions")) {
                                            // Give a nice error message suggesting how to fix the problem.
                                            die!(expr => "Ranges are only supported as num_enum alternate values if the `complex-expressions` feature of the crate `num_enum` is enabled".to_string())
                                        }
                                    }
                                    // We can't do uniqueness checking on non-literals, so we don't allow them as alternate values.
                                    // We could probably allow them, but there doesn't seem to be much of a use-case,
                                    // and it's easier to give good error messages about duplicate values this way,
                                    // rather than rustc errors on conflicting match branches.
                                    die!(expr => "Only literals are allowed as num_enum alternate values".to_string())
                                },
                            }
                        })
                        .collect::<Result<Vec<i128>>>()?;
                    let mut sorted_alternate_int_values = alternate_int_values.clone();
                    sorted_alternate_int_values.sort_unstable();
                    let sorted_alternate_int_values = sorted_alternate_int_values;

                    // Reject an alternative equal to this variant's own literal discriminant.
                    if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
                        if let Some(index) = alternate_int_values
                            .iter()
                            .position(|&x| x == canonical_value_int)
                        {
                            die!(&flattened_raw_alternative_values[index] => format!("'{}' in the alternative values is already attributed as the discriminant of this variant", canonical_value_int));
                        }
                    }

                    // The vec is sorted, so duplicates are adjacent; reject them with an error.
                    if (1..sorted_alternate_int_values.len()).any(|i| {
                        sorted_alternate_int_values[i] == sorted_alternate_int_values[i - 1]
                    }) {
                        // Non-empty: alternatives are only collected alongside a push here.
                        // The error points at the last `alternatives` attribute of the variant.
                        let attr = *alt_attr_ref.last().unwrap();
                        die!(attr => "There is duplication in the alternative values");
                    }
                    // Check whether any alternative value was already claimed by a previous variant.
                    // (discriminant_int_val_set is a BTreeSet, and iter().next_back() is the maximum in the set.)
                    // The per-value scan is skipped when even the smallest alternative exceeds that
                    // maximum. `first().unwrap()` cannot fail: this branch requires alternatives.
                    if let Some(last_upper_val) = discriminant_int_val_set.iter().next_back() {
                        if sorted_alternate_int_values.first().unwrap() <= last_upper_val {
                            for (index, val) in alternate_int_values.iter().enumerate() {
                                if discriminant_int_val_set.contains(val) {
                                    die!(&flattened_raw_alternative_values[index] => format!("'{}' in the alternative values is already attributed to a previous variant", val));
                                }
                            }
                        }
                    }

                    // Reconstruct the alternative_values vec of Expr but sorted.
                    flattened_raw_alternative_values = sorted_alternate_int_values
                        .iter()
                        .map(|val| literal(val.to_owned()))
                        .collect();

                    // Add the alternative values to the set to keep track.
                    discriminant_int_val_set.extend(sorted_alternate_int_values);
                }

                // Add the current discriminant to the set to keep track.
                if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
                    discriminant_int_val_set.insert(canonical_value_int);
                }

                variants.push(VariantInfo {
                    ident,
                    is_default,
                    is_catch_all,
                    canonical_value: discriminant,
                    alternative_values: flattened_raw_alternative_values,
                });

                // Get the next value for the discriminant.
                // A literal is incremented here; for any other expression the addition is
                // emitted as `<repr>::wrapping_add(expr, 1)` for the compiler to evaluate.
                next_discriminant = match discriminant_value {
                    DiscriminantValue::Literal(int_value) => literal(int_value.wrapping_add(1)),
                    DiscriminantValue::Expr(expr) => {
                        parse_quote! {
                            #repr::wrapping_add(#expr, 1)
                        }
                    }
                }
            }

            // Without a configured error type, fall back to the crate's
            // `TryFromPrimitiveError<Self>`, addressed through the resolved crate name.
            let error_type_info = maybe_error_type.unwrap_or_else(|| {
                let crate_name = Ident::new(&get_crate_name(), Span::call_site());
                ErrorType {
                    name: parse_quote! {
                        ::#crate_name::TryFromPrimitiveError<Self>
                    },
                    constructor: parse_quote! {
                        ::#crate_name::TryFromPrimitiveError::<Self>::new
                    },
                }
            });

            EnumInfo {
                name,
                repr,
                variants,
                error_type_info,
            }
        })
    }
}
409
literal(i: i128) -> Expr410 fn literal(i: i128) -> Expr {
411 Expr::Lit(ExprLit {
412 lit: Lit::Int(LitInt::new(&i.to_string(), Span::call_site())),
413 attrs: vec![],
414 })
415 }
416
417 enum DiscriminantValue {
418 Literal(i128),
419 Expr(Expr),
420 }
421
parse_discriminant(val_exp: &Expr) -> Result<DiscriminantValue>422 fn parse_discriminant(val_exp: &Expr) -> Result<DiscriminantValue> {
423 let mut sign = 1;
424 let mut unsigned_expr = val_exp;
425 if let Expr::Unary(ExprUnary {
426 op: UnOp::Neg(..),
427 expr,
428 ..
429 }) = val_exp
430 {
431 unsigned_expr = expr;
432 sign = -1;
433 }
434 if let Expr::Lit(ExprLit {
435 lit: Lit::Int(ref lit_int),
436 ..
437 }) = unsigned_expr
438 {
439 Ok(DiscriminantValue::Literal(
440 sign * lit_int.base10_parse::<i128>()?,
441 ))
442 } else {
443 Ok(DiscriminantValue::Expr(val_exp.clone()))
444 }
445 }
446
#[cfg(feature = "complex-expressions")]
/// Expands one `alternatives` entry into the values it stands for.
///
/// A range such as `2..5` or `2..=5` with explicit integer-literal bounds expands to each
/// integer it covers. Any other expression parses to a single value via `parse_discriminant`.
///
/// # Errors
///
/// Fails if a range bound is missing or not an integer literal, if the lower bound exceeds the
/// upper bound, or if `parse_discriminant` fails.
fn parse_alternative_values(val_expr: &Expr) -> Result<Vec<DiscriminantValue>> {
    /// Resolves one bound of a range alternative to an integer, requiring an explicit
    /// (optionally negated) integer literal.
    fn range_expr_value_to_number(
        parent_range_expr: &Expr,
        range_bound_value: &Option<Box<Expr>>,
    ) -> Result<i128> {
        // Avoid needing to calculate what the lower and upper bound would be - these are type dependent,
        // and also may not be obvious in context (e.g. an omitted bound could reasonably mean "from the last discriminant" or "from the lower bound of the type").
        if let Some(range_bound_value) = range_bound_value {
            let range_bound_value = parse_discriminant(range_bound_value.as_ref())?;
            // If non-literals are used, we can't expand to the mapped values, so can't write a nice match statement or do exhaustiveness checking.
            // Require literals instead.
            if let DiscriminantValue::Literal(value) = range_bound_value {
                return Ok(value);
            }
        }
        die!(parent_range_expr => "When ranges are used for alternate values, both bounds most be explicitly specified numeric literals")
    }

    if let Expr::Range(syn::ExprRange {
        start, end, limits, ..
    }) = val_expr
    {
        let lower = range_expr_value_to_number(val_expr, start)?;
        let upper = range_expr_value_to_number(val_expr, end)?;
        // While this is technically allowed in Rust, and results in an empty range, it's almost certainly a mistake in this context.
        if lower > upper {
            die!(val_expr => "When using ranges for alternate values, upper bound must not be less than lower bound");
        }
        // Iterate std ranges instead of stepping by hand. `..=` then stops at `i128::MAX`
        // without overflowing, and no `upper - lower` is computed that could overflow.
        let values: Vec<DiscriminantValue> = match limits {
            syn::RangeLimits::HalfOpen(..) => {
                (lower..upper).map(DiscriminantValue::Literal).collect()
            }
            syn::RangeLimits::Closed(..) => {
                (lower..=upper).map(DiscriminantValue::Literal).collect()
            }
        };
        Ok(values)
    } else {
        parse_discriminant(val_expr).map(|v| vec![v])
    }
}
498
499 #[cfg(not(feature = "complex-expressions"))]
parse_alternative_values(val_expr: &Expr) -> Result<Vec<DiscriminantValue>>500 fn parse_alternative_values(val_expr: &Expr) -> Result<Vec<DiscriminantValue>> {
501 parse_discriminant(val_expr).map(|v| vec![v])
502 }
503
504 pub(crate) struct VariantInfo {
505 ident: Ident,
506 is_default: bool,
507 is_catch_all: bool,
508 canonical_value: Expr,
509 alternative_values: Vec<Expr>,
510 }
511
512 impl VariantInfo {
all_values(&self) -> impl Iterator<Item = &Expr>513 fn all_values(&self) -> impl Iterator<Item = &Expr> {
514 ::core::iter::once(&self.canonical_value).chain(self.alternative_values.iter())
515 }
516 }
517
518 pub(crate) struct ErrorType {
519 pub(crate) name: Path,
520 pub(crate) constructor: Path,
521 }
522
523 impl From<ErrorTypeAttribute> for ErrorType {
from(attribute: ErrorTypeAttribute) -> Self524 fn from(attribute: ErrorTypeAttribute) -> Self {
525 Self {
526 name: attribute.name.path,
527 constructor: attribute.constructor.path,
528 }
529 }
530 }
531
/// Returns the name under which `num_enum` is depended on, looked up via `proc_macro_crate`, so
/// that generated paths still resolve when the dependency is renamed.
///
/// If the lookup fails, a warning is printed to stderr and `num_enum` is assumed.
#[cfg(feature = "proc-macro-crate")]
pub(crate) fn get_crate_name() -> String {
    use proc_macro_crate::FoundCrate;

    match proc_macro_crate::crate_name("num_enum") {
        Ok(FoundCrate::Name(renamed)) => renamed,
        Ok(FoundCrate::Itself) => String::from("num_enum"),
        Err(err) => {
            eprintln!("Warning: {}\n => defaulting to `num_enum`", err);
            String::from("num_enum")
        }
    }
}
544
// Without the `proc-macro-crate` feature, the crate name is hard-coded. That dependency is
// avoided in no_std environments because it brings in an awkward dependency on serde with std.
//
// As a result, no_std dependents of num_enum cannot rename the num_enum crate when they depend
// on it. Sorry.
//
// See https://github.com/illicitonion/num_enum/issues/18
#[cfg(not(feature = "proc-macro-crate"))]
pub(crate) fn get_crate_name() -> String {
    "num_enum".to_owned()
}
555