xref: /aosp_15_r20/external/gsc-utils/rust/enum_utils/src/lib.rs (revision 4f2df630800bdcf1d4f0decf95d8a1cb87344f5f)
1 // Copyright 2023 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 use std::cmp::Ordering;
5 
6 use proc_macro::*;
7 
/// Finds the first `enum` declared in `stream` and returns its name together with the token
/// stream inside its body.
///
/// The name is the identifier immediately following the `enum` keyword, and the body is the first
/// group encountered once that name is known. Groups seen before a name has been captured (for
/// example attribute brackets) are searched recursively. Returns `None` when no enum is found.
fn get_enum_and_stream(stream: TokenStream) -> Option<(TokenTree, TokenStream)> {
    let mut found_name: Option<TokenTree> = None;
    let mut expecting_name = false;

    for token in stream {
        match token {
            TokenTree::Ident(ident) => {
                if ident.to_string() == "enum" {
                    expecting_name = true;
                } else if expecting_name {
                    expecting_name = false;
                    found_name = Some(TokenTree::Ident(ident));
                }
            }
            TokenTree::Group(group) => match found_name {
                Some(enum_name) => return Some((enum_name, group.stream())),
                None => {
                    if let found @ Some(_) = get_enum_and_stream(group.stream()) {
                        return found;
                    }
                }
            },
            _ => {}
        }
    }

    None
}
33 
generate_enum_array(name: TokenStream, mut item: TokenStream, header: &str) -> TokenStream34 fn generate_enum_array(name: TokenStream, mut item: TokenStream, header: &str) -> TokenStream {
35     let (enum_name, enum_stream) = get_enum_and_stream(item.clone()).expect("Must use on enum");
36     // Check that enums do not have associated fields, but see still want to all doc comments.
37     for tree in enum_stream.clone() {
38         match tree {
39             TokenTree::Group(group) => match group.stream().into_iter().next() {
40                 Some(TokenTree::Ident(ident)) if ident.to_string() == "doc" => {
41                     // This is a doc comment; this is the only allowed group right now since we
42                     // do not support associated fields on any enum value
43                 }
44                 _ => panic!("Only enums without associated fields are supported"),
45             },
46             _ => {
47                 // If no groups, then this enum should just be a normal "C-Style" enum.
48             }
49         }
50     }
51 
52     let qualified_list = enum_stream
53         .into_iter()
54         .map(|tree| {
55             if let TokenTree::Ident(det_name) = tree {
56                 format!("{}::{},", enum_name, det_name).parse().unwrap()
57             } else {
58                 TokenStream::new()
59             }
60         })
61         .collect::<TokenStream>();
62     let array_stream: TokenStream = format!(
63         r#"{header}
64         pub const {name}: &[{enum_name}] = &[{qualified_list}];"#,
65         header = header,
66         name = name,
67         enum_name = enum_name,
68         qualified_list = qualified_list,
69     )
70     .parse()
71     .unwrap();
72 
73     // Emitted exactly what was written, then add the test array
74     item.extend(array_stream);
75     item
76 }
77 
78 /// Generates an test cfg array with the specified name containing all enum values. This is only
79 /// valid for enums where all the variants do not have fields associated with them.
80 /// Access to array is possible only in test builds
81 ///
82 /// # Example
83 ///
84 /// ```
85 /// # #[macro_use] extern crate enum_utils;
86 /// #[gen_test_enum_array(MyTestArrayName)]
87 /// enum MyEnum {
88 ///     ValueOne,
89 ///     ValueTwo,
90 /// }
91 ///
92 /// #[cfg(test)]
93 /// mod tests {
94 ///     #[test]
95 ///     fn test_two_values() {
96 ///         assert_eq!(MyTestArrayName.len(), 2);
97 ///     }
98 /// }
99 /// ```
100 #[proc_macro_attribute]
gen_test_enum_array(name: TokenStream, item: TokenStream) -> TokenStream101 pub fn gen_test_enum_array(name: TokenStream, item: TokenStream) -> TokenStream {
102     generate_enum_array(name, item, "#[cfg(test)]")
103 }
104 
105 /// Generates an array with the specified name containing all enum values. This is only valid for
106 /// enums where all the variants do not have fields associated with them.
107 ///
108 /// # Example
109 ///
110 /// ```
111 /// # #[macro_use] extern crate enum_utils;
112 /// #[gen_enum_array(MyArrayName)]
113 /// pub enum MyEnum {
114 ///     ValueOne,
115 ///     ValueTwo,
116 /// }
117 ///
118 /// fn check() {
119 ///         assert_eq!(MyArrayName.len(), 2);
120 /// }
121 /// ```
122 #[proc_macro_attribute]
gen_enum_array(name: TokenStream, item: TokenStream) -> TokenStream123 pub fn gen_enum_array(name: TokenStream, item: TokenStream) -> TokenStream {
124     generate_enum_array(name, item, "")
125 }
126 
127 /// Generates an impl to_string for enum which returns &str.
128 /// Method is implemended using match for every enum variant.
129 /// Valid for enums where all the variants do not have fields associated with them.
130 ///
131 /// # Example
132 ///
133 /// ```
134 /// # #[macro_use] extern crate enum_utils;
135 /// #[gen_to_string]
136 /// enum MyEnum {
137 ///     ValueOne,
138 ///     ValueTwo,
139 /// }
140 ///
141 /// fn main() {
142 ///     let e = MyEnum::ValueOne;
143 ///     println!("{}", e.to_string());
144 /// }
145 ///
146 /// ```
147 #[proc_macro_attribute]
gen_to_string(_attr: TokenStream, mut input: TokenStream) -> TokenStream148 pub fn gen_to_string(_attr: TokenStream, mut input: TokenStream) -> TokenStream {
149     let (enum_name, enum_stream) = get_enum_and_stream(input.clone()).expect("Must use on enum");
150 
151     let mut match_arms = TokenStream::new();
152     let enums_items = enum_stream.into_iter().filter_map(|tt| match tt {
153         TokenTree::Ident(id) => Some(id.to_string()),
154         _ => None,
155     });
156     for item in enums_items {
157         let arm: TokenStream = format!(
158             r#"{enum_name}::{item} => "{item}","#,
159             enum_name = enum_name,
160             item = item,
161         )
162         .parse()
163         .unwrap();
164         match_arms.extend(arm);
165     }
166 
167     let implementation: TokenStream = format!(
168         r#"impl {enum_name} {{
169         pub fn to_string(&self) -> &'static str {{
170             match *self {{
171                 {match_arms}
172             }}
173         }}
174     }}"#,
175         enum_name = enum_name,
176         match_arms = match_arms
177     )
178     .parse()
179     .unwrap();
180 
181     // Emit input as is and add the to_string implementation
182     input.extend(implementation);
183     input
184 }
185 
/// Parses the text of an integer literal token (as produced by `Literal::to_string`) into an
/// `i128`, negating it when `negative` is set.
///
/// Handles the `0x`/`0o`/`0b` prefixes, `_` digit separators and an optional integer type suffix
/// (`u8`, `i32`, `usize`, ...). Neither `i` nor `u` is a digit in any supported radix, so the
/// suffix starts at the first such character. The sign is applied before conversion, so a
/// magnitude equal to `i128::MIN` does not overflow.
///
/// # Panics
///
/// Panics with `Invalid number <val>` if the digits do not form a valid `i128`.
fn parse_to_i128(val: &str, negative: bool) -> i128 {
    let (body, radix) = if let Some(rest) = val.strip_prefix("0x") {
        (rest, 16)
    } else if let Some(rest) = val.strip_prefix("0o") {
        (rest, 8)
    } else if let Some(rest) = val.strip_prefix("0b") {
        (rest, 2)
    } else {
        (val, 10)
    };

    // Drop a type suffix such as `u8` or `isize`.
    let body = match body.find(|c: char| c == 'i' || c == 'u') {
        Some(pos) => &body[..pos],
        None => body,
    };

    // Remove any helper `_`, prefixing `-` so the full negative range (including i128::MIN)
    // parses without overflow.
    let mut digits = String::with_capacity(body.len() + 1);
    if negative {
        digits.push('-');
    }
    digits.extend(body.chars().filter(|&c| c != '_'));

    i128::from_str_radix(&digits, radix).unwrap_or_else(|_| panic!("Invalid number {}", val))
}
206 
207 /// Generates a exclusive `END` const and `from_<repr>` function for an enum
208 ///
209 /// # Example
210 ///
211 /// ```
212 /// # #[macro_use] extern crate enum_utils;
213 /// #[enum_as(u8)]
214 /// enum MyEnum {
215 ///     ValueZero,
216 ///     ValueOne,
217 /// }
218 ///
219 /// # fn main() {
220 /// assert!(matches!(MyEnum::from_u8(1), Some(MyEnum::ValueOne)));
221 /// assert_eq!(MyEnum::END, 2);
222 /// # }
223 ///
224 /// ```
225 #[proc_macro_attribute]
enum_as(repr: TokenStream, input: TokenStream) -> TokenStream226 pub fn enum_as(repr: TokenStream, input: TokenStream) -> TokenStream {
227     let repr = repr.to_string();
228     let (enum_name, enum_stream) = get_enum_and_stream(input.clone()).expect("Must use on enum");
229 
230     #[derive(Debug, Copy, Clone)]
231     enum WantState {
232         Determinant,
233         Punc,
234         NegativeVal,
235         Val,
236         CommentBlock,
237     }
238     let mut state = WantState::Determinant;
239     let mut skipped_ranges = vec![];
240     let mut start: Option<i128> = None;
241     let mut end: Option<i128> = None;
242     for tt in enum_stream {
243         use WantState::*;
244         state = match (tt, state) {
245             (TokenTree::Ident(_), Determinant) => Punc,
246             // If we are expecting a Determinant, but get a # instead, it must be a comment
247             (TokenTree::Punct(p), Determinant) if p.as_char() == '#' => CommentBlock,
248             (TokenTree::Punct(p), Val) if p.as_char() == '-' => NegativeVal,
249             (TokenTree::Group(_), CommentBlock) => Determinant,
250             (TokenTree::Punct(p), Punc) => match p.as_char() {
251                 ',' => {
252                     start = Some(start.unwrap_or(0));
253                     end = Some(end.unwrap_or_else(|| start.unwrap()) + 1);
254                     Determinant
255                 }
256                 '=' => Val,
257                 other => panic!("Unexpected punctuation '{}'", other),
258             },
259             (TokenTree::Literal(l), Val | NegativeVal) => {
260                 let val = parse_to_i128(&l.to_string(), matches!(state, NegativeVal));
261                 start = Some(start.unwrap_or(val));
262                 let expected = end.unwrap_or_else(|| start.unwrap());
263                 match val.cmp(&expected) {
264                     Ordering::Greater => {
265                         skipped_ranges.push((expected, val));
266                         end = Some(val);
267                     }
268                     Ordering::Less => panic!("Discriminants must increase in value"),
269                     Ordering::Equal => (),
270                 }
271                 Punc
272             }
273             (tt, want) => {
274                 panic!("Want {:?} but got {:?}", want, tt)
275             }
276         };
277     }
278 
279     let skipped_ranges = if skipped_ranges.is_empty() {
280         "".to_string()
281     } else {
282         format!(
283             r#"for r in &{ranges:?} {{
284                 if (r.0..r.1).contains(&val) {{
285                     return None;
286                 }}
287             }}"#,
288             ranges = skipped_ranges
289         )
290     };
291 
292     // Ensure that there is at least one discriminant
293     let start = start.expect("Enum needs at least one discriminant");
294     let end = end.expect("Enum needs at least one discriminant");
295 
296     // Ensure that END will fit into usize or u64
297     u64::try_from(end).unwrap_or_else(|_| panic!("Value after last discriminant must be unsigned"));
298 
299     let implementation: TokenStream = format!(
300         r#"impl {enum_name} {{
301             pub const END: {end_type} = {end};
302 
303             pub fn from_{repr}(val: {repr}) -> Option<Self> {{
304                 if val < {start} {{
305                     return None;
306                 }}
307                 if val > ((Self::END - 1) as {repr}) {{
308                     return None;
309                 }}
310                 {skipped_ranges}
311                 Some( unsafe {{ core::mem::transmute(val) }})
312             }}
313         }}
314 
315         impl PartialEq for {enum_name} {{
316             fn eq(&self, other: &Self) -> bool {{
317                 *self as {repr} == *other as {repr}
318             }}
319         }}
320 
321         impl Eq for {enum_name} {{ }}
322 
323         #[cfg(not(target_arch = "riscv32"))]
324         impl core::hash::Hash for {enum_name} {{
325             fn hash<H: core::hash::Hasher>(&self, state: &mut H) {{
326                 (*self as {repr}).hash(state);
327             }}
328          }}"#,
329         enum_name = enum_name,
330         start = start,
331         end = end,
332         end_type = if repr == "u64" { "u64" } else { "usize " },
333         repr = repr,
334         skipped_ranges = skipped_ranges,
335     )
336     .parse()
337     .unwrap();
338 
339     // Attribute input with a repr, Clone, Clone, and allow(dead_code), then add the custom
340     // implementation. We allow dead code since the these enums are typically interface enums that
341     // define an API boundary.
342     let mut res: TokenStream = format!(
343         "#[repr({})]\n#[derive(Copy, Clone)]\n#[allow(dead_code)]",
344         repr
345     )
346     .parse()
347     .unwrap();
348     res.extend(input);
349     res.extend(implementation);
350     res
351 }
352 
/// Builds the comma-separated argument list used to forward a method call.
///
/// `params` is the contents of a method's parenthesized parameter list. The receiver is dropped:
/// that is any parameter whose pattern mentions `self`, e.g. `&self`, `&'a mut self` or
/// `self: &Self`. Every other parameter contributes the binding name from its pattern; `mut`,
/// `ref` and attributes are skipped. Commas nested inside generic arguments
/// (`HashMap<K, V>`) do not split parameters, and the `>` of an `->` arrow is not treated as a
/// closing angle bracket.
///
/// For `&self, plus: usize, mut other: Vec<u8>` this returns `"plus,other"`.
fn get_param_list_without_self(params: TokenStream) -> String {
    // Split the stream into one token list per parameter, on top-level commas only.
    let mut segments: Vec<Vec<TokenTree>> = vec![Vec::new()];
    let mut angle_depth: usize = 0;
    let mut prev_joint_minus = false;
    for tt in params {
        let mut joint_minus = false;
        if let TokenTree::Punct(p) = &tt {
            match p.as_char() {
                '<' => angle_depth += 1,
                // The `>` of `->` (in `fn(..) -> T` / `Fn(..) -> T`) is not a bracket.
                '>' if !prev_joint_minus => angle_depth = angle_depth.saturating_sub(1),
                '-' => joint_minus = p.spacing() == Spacing::Joint,
                ',' if angle_depth == 0 => {
                    segments.push(Vec::new());
                    prev_joint_minus = false;
                    continue;
                }
                _ => {}
            }
        }
        prev_joint_minus = joint_minus;
        if let Some(current) = segments.last_mut() {
            current.push(tt);
        }
    }

    let names: Vec<String> = segments
        .iter()
        .filter_map(|segment| {
            // Only the tokens before the `:` type separator form the parameter's pattern.
            let pattern_idents: Vec<String> = segment
                .iter()
                .take_while(|tt| !matches!(tt, TokenTree::Punct(p) if p.as_char() == ':'))
                .filter_map(|tt| match tt {
                    TokenTree::Ident(ident) => Some(ident.to_string()),
                    _ => None,
                })
                .collect();
            if pattern_idents.iter().any(|ident| ident == "self") {
                // The receiver is not forwarded as an argument.
                return None;
            }
            pattern_idents
                .into_iter()
                .find(|ident| ident != "mut" && ident != "ref")
        })
        .collect();

    names.join(",")
}
384 
385 /// Replaces the function body with a single statement that pass all parameters thru to same
386 /// function name on the variable specified in the passthru_to macro
387 ///
388 /// # Example
389 ///
390 /// ```
391 /// # #[macro_use] extern crate enum_utils;
392 /// pub struct Inner(usize);
393 ///
394 /// impl Inner {
395 ///     pub fn read_plus(&self, plus: usize) -> usize {
396 ///         self.0 + plus
397 ///     }
398 /// }
399 ///
400 /// pub struct Outer {
401 ///     pub my_inner: Inner,
402 /// }
403 ///
404 /// impl Outer {
405 ///     #[passthru_to(my_inner)]
406 ///     pub fn read_plus(&self, plus: usize) -> usize {}
407 /// }
408 ///
409 /// ```
410 #[proc_macro_attribute]
passthru_to(passthru_var: TokenStream, input: TokenStream) -> TokenStream411 pub fn passthru_to(passthru_var: TokenStream, input: TokenStream) -> TokenStream {
412     #[derive(Debug, Copy, Clone)]
413     enum WantState {
414         FunctionKeyword,
415         FunctionName,
416         Parameters,
417         Body,
418         End,
419     }
420     let mut state = WantState::FunctionKeyword;
421     let mut name = None;
422     let mut params = None;
423     let mut body_num = None;
424     for (i, tt) in input.clone().into_iter().enumerate() {
425         use WantState::*;
426         state = match (tt, state) {
427             (TokenTree::Ident(ident), FunctionKeyword) if ident.to_string() == "fn" => FunctionName,
428             (_, FunctionKeyword) => FunctionKeyword,
429             (TokenTree::Ident(ident), FunctionName) => {
430                 name = Some(ident.to_string());
431                 Parameters
432             }
433             (TokenTree::Group(group), Parameters) => {
434                 params = Some(get_param_list_without_self(group.stream()));
435                 Body
436             }
437             (TokenTree::Group(group), Body) if group.delimiter() == Delimiter::Brace => {
438                 body_num = Some(i);
439                 End
440             }
441             (_, Body) => Body,
442             (tt, want) => {
443                 panic!("Want {:?} but got {:?}", want, tt)
444             }
445         };
446     }
447 
448     let implementation: TokenStream = format!(
449         r#"{{ self.{passthru_var}.{name}({params}) }}"#,
450         passthru_var = passthru_var,
451         name = name.expect("Cannot find function name"),
452         params = params.expect("Could find parameters"),
453     )
454     .parse()
455     .unwrap();
456 
457     // Take everything up to body, then replace body with the passthru implementation
458     input
459         .into_iter()
460         .take(body_num.expect("Cannot find body"))
461         .chain(implementation)
462         .collect()
463 }
464