// Inlined format args (`format!("{x}")`), which this lint suggests, are not supported by our MSRV.
2 #![allow(clippy::uninlined_format_args)]
3 
4 extern crate proc_macro;
5 
6 use proc_macro::TokenStream;
7 use proc_macro2::Span;
8 use quote::quote;
9 use syn::{parse_macro_input, Expr, Ident};
10 
11 mod enum_attributes;
12 mod parsing;
13 use parsing::{get_crate_name, EnumInfo};
14 mod utils;
15 mod variant_attributes;
16 
17 /// Implements `Into<Primitive>` for a `#[repr(Primitive)] enum`.
18 ///
19 /// (It actually implements `From<Enum> for Primitive`)
20 ///
21 /// ## Allows turning an enum into a primitive.
22 ///
23 /// ```rust
24 /// use num_enum::IntoPrimitive;
25 ///
26 /// #[derive(IntoPrimitive)]
27 /// #[repr(u8)]
28 /// enum Number {
29 ///     Zero,
30 ///     One,
31 /// }
32 ///
33 /// let zero: u8 = Number::Zero.into();
34 /// assert_eq!(zero, 0u8);
35 /// ```
36 #[proc_macro_derive(IntoPrimitive, attributes(num_enum, catch_all))]
derive_into_primitive(input: TokenStream) -> TokenStream37 pub fn derive_into_primitive(input: TokenStream) -> TokenStream {
38     let enum_info = parse_macro_input!(input as EnumInfo);
39     let catch_all = enum_info.catch_all();
40     let name = &enum_info.name;
41     let repr = &enum_info.repr;
42 
43     let body = if let Some(catch_all_ident) = catch_all {
44         quote! {
45             match enum_value {
46                 #name::#catch_all_ident(raw) => raw,
47                 rest => unsafe { *(&rest as *const #name as *const Self) }
48             }
49         }
50     } else {
51         quote! { enum_value as Self }
52     };
53 
54     TokenStream::from(quote! {
55         impl From<#name> for #repr {
56             #[inline]
57             fn from (enum_value: #name) -> Self
58             {
59                 #body
60             }
61         }
62     })
63 }
64 
65 /// Implements `From<Primitive>` for a `#[repr(Primitive)] enum`.
66 ///
67 /// Turning a primitive into an enum with `from`.
68 /// ----------------------------------------------
69 ///
70 /// ```rust
71 /// use num_enum::FromPrimitive;
72 ///
73 /// #[derive(Debug, Eq, PartialEq, FromPrimitive)]
74 /// #[repr(u8)]
75 /// enum Number {
76 ///     Zero,
77 ///     #[num_enum(default)]
78 ///     NonZero,
79 /// }
80 ///
81 /// let zero = Number::from(0u8);
82 /// assert_eq!(zero, Number::Zero);
83 ///
84 /// let one = Number::from(1u8);
85 /// assert_eq!(one, Number::NonZero);
86 ///
87 /// let two = Number::from(2u8);
88 /// assert_eq!(two, Number::NonZero);
89 /// ```
90 #[proc_macro_derive(FromPrimitive, attributes(num_enum, default, catch_all))]
derive_from_primitive(input: TokenStream) -> TokenStream91 pub fn derive_from_primitive(input: TokenStream) -> TokenStream {
92     let enum_info: EnumInfo = parse_macro_input!(input);
93     let krate = Ident::new(&get_crate_name(), Span::call_site());
94 
95     let is_naturally_exhaustive = enum_info.is_naturally_exhaustive();
96     let catch_all_body = match is_naturally_exhaustive {
97         Ok(is_naturally_exhaustive) => {
98             if is_naturally_exhaustive {
99                 quote! { unreachable!("exhaustive enum") }
100             } else if let Some(default_ident) = enum_info.default() {
101                 quote! { Self::#default_ident }
102             } else if let Some(catch_all_ident) = enum_info.catch_all() {
103                 quote! { Self::#catch_all_ident(number) }
104             } else {
105                 let span = Span::call_site();
106                 let message =
107                     "#[derive(num_enum::FromPrimitive)] requires enum to be exhaustive, or a variant marked with `#[default]`, `#[num_enum(default)]`, or `#[num_enum(catch_all)`";
108                 return syn::Error::new(span, message).to_compile_error().into();
109             }
110         }
111         Err(err) => {
112             return err.to_compile_error().into();
113         }
114     };
115 
116     let EnumInfo {
117         ref name, ref repr, ..
118     } = enum_info;
119 
120     let variant_idents: Vec<Ident> = enum_info.variant_idents();
121     let expression_idents: Vec<Vec<Ident>> = enum_info.expression_idents();
122     let variant_expressions: Vec<Vec<Expr>> = enum_info.variant_expressions();
123 
124     debug_assert_eq!(variant_idents.len(), variant_expressions.len());
125 
126     TokenStream::from(quote! {
127         impl ::#krate::FromPrimitive for #name {
128             type Primitive = #repr;
129 
130             fn from_primitive(number: Self::Primitive) -> Self {
131                 // Use intermediate const(s) so that enums defined like
132                 // `Two = ONE + 1u8` work properly.
133                 #![allow(non_upper_case_globals)]
134                 #(
135                     #(
136                         const #expression_idents: #repr = #variant_expressions;
137                     )*
138                 )*
139                 #[deny(unreachable_patterns)]
140                 match number {
141                     #(
142                         #( #expression_idents )|*
143                         => Self::#variant_idents,
144                     )*
145                     #[allow(unreachable_patterns)]
146                     _ => #catch_all_body,
147                 }
148             }
149         }
150 
151         impl ::core::convert::From<#repr> for #name {
152             #[inline]
153             fn from (
154                 number: #repr,
155             ) -> Self {
156                 ::#krate::FromPrimitive::from_primitive(number)
157             }
158         }
159 
160         #[doc(hidden)]
161         impl ::#krate::CannotDeriveBothFromPrimitiveAndTryFromPrimitive for #name {}
162     })
163 }
164 
165 /// Implements `TryFrom<Primitive>` for a `#[repr(Primitive)] enum`.
166 ///
167 /// Attempting to turn a primitive into an enum with `try_from`.
168 /// ----------------------------------------------
169 ///
170 /// ```rust
171 /// use num_enum::TryFromPrimitive;
172 /// use std::convert::TryFrom;
173 ///
174 /// #[derive(Debug, Eq, PartialEq, TryFromPrimitive)]
175 /// #[repr(u8)]
176 /// enum Number {
177 ///     Zero,
178 ///     One,
179 /// }
180 ///
181 /// let zero = Number::try_from(0u8);
182 /// assert_eq!(zero, Ok(Number::Zero));
183 ///
184 /// let three = Number::try_from(3u8);
185 /// assert_eq!(
186 ///     three.unwrap_err().to_string(),
187 ///     "No discriminant in enum `Number` matches the value `3`",
188 /// );
189 /// ```
190 #[proc_macro_derive(TryFromPrimitive, attributes(num_enum))]
derive_try_from_primitive(input: TokenStream) -> TokenStream191 pub fn derive_try_from_primitive(input: TokenStream) -> TokenStream {
192     let enum_info: EnumInfo = parse_macro_input!(input);
193     let krate = Ident::new(&get_crate_name(), Span::call_site());
194 
195     let EnumInfo {
196         ref name,
197         ref repr,
198         ref error_type_info,
199         ..
200     } = enum_info;
201 
202     let variant_idents: Vec<Ident> = enum_info.variant_idents();
203     let expression_idents: Vec<Vec<Ident>> = enum_info.expression_idents();
204     let variant_expressions: Vec<Vec<Expr>> = enum_info.variant_expressions();
205 
206     debug_assert_eq!(variant_idents.len(), variant_expressions.len());
207 
208     let error_type = &error_type_info.name;
209     let error_constructor = &error_type_info.constructor;
210 
211     TokenStream::from(quote! {
212         impl ::#krate::TryFromPrimitive for #name {
213             type Primitive = #repr;
214             type Error = #error_type;
215 
216             const NAME: &'static str = stringify!(#name);
217 
218             fn try_from_primitive (
219                 number: Self::Primitive,
220             ) -> ::core::result::Result<
221                 Self,
222                 #error_type
223             > {
224                 // Use intermediate const(s) so that enums defined like
225                 // `Two = ONE + 1u8` work properly.
226                 #![allow(non_upper_case_globals)]
227                 #(
228                     #(
229                         const #expression_idents: #repr = #variant_expressions;
230                     )*
231                 )*
232                 #[deny(unreachable_patterns)]
233                 match number {
234                     #(
235                         #( #expression_idents )|*
236                         => ::core::result::Result::Ok(Self::#variant_idents),
237                     )*
238                     #[allow(unreachable_patterns)]
239                     _ => ::core::result::Result::Err(
240                         #error_constructor ( number )
241                     ),
242                 }
243             }
244         }
245 
246         impl ::core::convert::TryFrom<#repr> for #name {
247             type Error = #error_type;
248 
249             #[inline]
250             fn try_from (
251                 number: #repr,
252             ) -> ::core::result::Result<Self, #error_type>
253             {
254                 ::#krate::TryFromPrimitive::try_from_primitive(number)
255             }
256         }
257 
258         #[doc(hidden)]
259         impl ::#krate::CannotDeriveBothFromPrimitiveAndTryFromPrimitive for #name {}
260     })
261 }
262 
263 /// Generates a `unsafe fn unchecked_transmute_from(number: Primitive) -> Self`
264 /// associated function.
265 ///
266 /// Allows unsafely turning a primitive into an enum with unchecked_transmute_from
267 /// ------------------------------------------------------------------------------
268 ///
269 /// If you're really certain a conversion will succeed, and want to avoid a small amount of overhead, you can use unsafe
270 /// code to do this conversion. Unless you have data showing that the match statement generated in the `try_from` above is a
271 /// bottleneck for you, you should avoid doing this, as the unsafe code has potential to cause serious memory issues in
272 /// your program.
273 ///
274 /// Note that this derive ignores any `default`, `catch_all`, and `alternatives` attributes on the enum.
275 /// If you need support for conversions from these values, you should use `TryFromPrimitive` or `FromPrimitive`.
276 ///
277 /// ```rust
278 /// use num_enum::UnsafeFromPrimitive;
279 ///
280 /// #[derive(Debug, Eq, PartialEq, UnsafeFromPrimitive)]
281 /// #[repr(u8)]
282 /// enum Number {
283 ///     Zero,
284 ///     One,
285 /// }
286 ///
287 /// fn main() {
288 ///     assert_eq!(
289 ///         Number::Zero,
290 ///         unsafe { Number::unchecked_transmute_from(0_u8) },
291 ///     );
292 ///     assert_eq!(
293 ///         Number::One,
294 ///         unsafe { Number::unchecked_transmute_from(1_u8) },
295 ///     );
296 /// }
297 ///
298 /// unsafe fn undefined_behavior() {
299 ///     let _ = Number::unchecked_transmute_from(2); // 2 is not a valid discriminant!
300 /// }
301 /// ```
302 #[proc_macro_derive(UnsafeFromPrimitive, attributes(num_enum))]
derive_unsafe_from_primitive(stream: TokenStream) -> TokenStream303 pub fn derive_unsafe_from_primitive(stream: TokenStream) -> TokenStream {
304     let enum_info = parse_macro_input!(stream as EnumInfo);
305     let krate = Ident::new(&get_crate_name(), Span::call_site());
306 
307     let EnumInfo {
308         ref name, ref repr, ..
309     } = enum_info;
310 
311     TokenStream::from(quote! {
312         impl ::#krate::UnsafeFromPrimitive for #name {
313             type Primitive = #repr;
314 
315             unsafe fn unchecked_transmute_from(number: Self::Primitive) -> Self {
316                 ::core::mem::transmute(number)
317             }
318         }
319     })
320 }
321 
322 /// Implements `core::default::Default` for a `#[repr(Primitive)] enum`.
323 ///
324 /// Whichever variant has the `#[default]` or `#[num_enum(default)]` attribute will be returned.
325 /// ----------------------------------------------
326 ///
327 /// ```rust
328 /// #[derive(Debug, Eq, PartialEq, num_enum::Default)]
329 /// #[repr(u8)]
330 /// enum Number {
331 ///     Zero,
332 ///     #[default]
333 ///     One,
334 /// }
335 ///
336 /// assert_eq!(Number::One, Number::default());
337 /// assert_eq!(Number::One, <Number as ::core::default::Default>::default());
338 /// ```
339 #[proc_macro_derive(Default, attributes(num_enum, default))]
derive_default(stream: TokenStream) -> TokenStream340 pub fn derive_default(stream: TokenStream) -> TokenStream {
341     let enum_info = parse_macro_input!(stream as EnumInfo);
342 
343     let default_ident = match enum_info.default() {
344         Some(ident) => ident,
345         None => {
346             let span = Span::call_site();
347             let message =
348                 "#[derive(num_enum::Default)] requires enum to be exhaustive, or a variant marked with `#[default]` or `#[num_enum(default)]`";
349             return syn::Error::new(span, message).to_compile_error().into();
350         }
351     };
352 
353     let EnumInfo { ref name, .. } = enum_info;
354 
355     TokenStream::from(quote! {
356         impl ::core::default::Default for #name {
357             #[inline]
358             fn default() -> Self {
359                 Self::#default_ident
360             }
361         }
362     })
363 }
364