1 // Copyright 2023 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 use std::cmp::Ordering;
5
6 use proc_macro::*;
7
/// Searches `stream` (recursing into groups) for an `enum` item.
///
/// Returns the identifier that follows the first `enum` keyword, together with the token
/// stream of the first group encountered after that identifier (the enum body). Any tokens
/// between the name and the body (e.g. generics) are skipped. Returns `None` if no such pair
/// is found.
fn get_enum_and_stream(stream: TokenStream) -> Option<(TokenTree, TokenStream)> {
    let mut found_name: Option<TokenTree> = None;
    // Set right after the `enum` keyword; the next identifier is the enum's name.
    let mut expect_name = false;

    for token in stream {
        match token {
            TokenTree::Ident(keyword) if keyword.to_string() == "enum" => expect_name = true,
            TokenTree::Ident(ident) if expect_name => {
                expect_name = false;
                found_name = Some(TokenTree::Ident(ident));
            }
            TokenTree::Group(group) => match found_name {
                // The first group after the name is the enum body.
                Some(name) => return Some((name, group.stream())),
                // No name yet: the enum may be nested inside this group.
                None => {
                    if let Some(nested) = get_enum_and_stream(group.stream()) {
                        return Some(nested);
                    }
                }
            },
            _ => (),
        }
    }

    None
}
33
generate_enum_array(name: TokenStream, mut item: TokenStream, header: &str) -> TokenStream34 fn generate_enum_array(name: TokenStream, mut item: TokenStream, header: &str) -> TokenStream {
35 let (enum_name, enum_stream) = get_enum_and_stream(item.clone()).expect("Must use on enum");
36 // Check that enums do not have associated fields, but see still want to all doc comments.
37 for tree in enum_stream.clone() {
38 match tree {
39 TokenTree::Group(group) => match group.stream().into_iter().next() {
40 Some(TokenTree::Ident(ident)) if ident.to_string() == "doc" => {
41 // This is a doc comment; this is the only allowed group right now since we
42 // do not support associated fields on any enum value
43 }
44 _ => panic!("Only enums without associated fields are supported"),
45 },
46 _ => {
47 // If no groups, then this enum should just be a normal "C-Style" enum.
48 }
49 }
50 }
51
52 let qualified_list = enum_stream
53 .into_iter()
54 .map(|tree| {
55 if let TokenTree::Ident(det_name) = tree {
56 format!("{}::{},", enum_name, det_name).parse().unwrap()
57 } else {
58 TokenStream::new()
59 }
60 })
61 .collect::<TokenStream>();
62 let array_stream: TokenStream = format!(
63 r#"{header}
64 pub const {name}: &[{enum_name}] = &[{qualified_list}];"#,
65 header = header,
66 name = name,
67 enum_name = enum_name,
68 qualified_list = qualified_list,
69 )
70 .parse()
71 .unwrap();
72
73 // Emitted exactly what was written, then add the test array
74 item.extend(array_stream);
75 item
76 }
77
78 /// Generates an test cfg array with the specified name containing all enum values. This is only
79 /// valid for enums where all the variants do not have fields associated with them.
80 /// Access to array is possible only in test builds
81 ///
82 /// # Example
83 ///
84 /// ```
85 /// # #[macro_use] extern crate enum_utils;
86 /// #[gen_test_enum_array(MyTestArrayName)]
87 /// enum MyEnum {
88 /// ValueOne,
89 /// ValueTwo,
90 /// }
91 ///
92 /// #[cfg(test)]
93 /// mod tests {
94 /// #[test]
95 /// fn test_two_values() {
96 /// assert_eq!(MyTestArrayName.len(), 2);
97 /// }
98 /// }
99 /// ```
100 #[proc_macro_attribute]
gen_test_enum_array(name: TokenStream, item: TokenStream) -> TokenStream101 pub fn gen_test_enum_array(name: TokenStream, item: TokenStream) -> TokenStream {
102 generate_enum_array(name, item, "#[cfg(test)]")
103 }
104
105 /// Generates an array with the specified name containing all enum values. This is only valid for
106 /// enums where all the variants do not have fields associated with them.
107 ///
108 /// # Example
109 ///
110 /// ```
111 /// # #[macro_use] extern crate enum_utils;
112 /// #[gen_enum_array(MyArrayName)]
113 /// pub enum MyEnum {
114 /// ValueOne,
115 /// ValueTwo,
116 /// }
117 ///
118 /// fn check() {
119 /// assert_eq!(MyArrayName.len(), 2);
120 /// }
121 /// ```
122 #[proc_macro_attribute]
gen_enum_array(name: TokenStream, item: TokenStream) -> TokenStream123 pub fn gen_enum_array(name: TokenStream, item: TokenStream) -> TokenStream {
124 generate_enum_array(name, item, "")
125 }
126
127 /// Generates an impl to_string for enum which returns &str.
128 /// Method is implemended using match for every enum variant.
129 /// Valid for enums where all the variants do not have fields associated with them.
130 ///
131 /// # Example
132 ///
133 /// ```
134 /// # #[macro_use] extern crate enum_utils;
135 /// #[gen_to_string]
136 /// enum MyEnum {
137 /// ValueOne,
138 /// ValueTwo,
139 /// }
140 ///
141 /// fn main() {
142 /// let e = MyEnum::ValueOne;
143 /// println!("{}", e.to_string());
144 /// }
145 ///
146 /// ```
147 #[proc_macro_attribute]
gen_to_string(_attr: TokenStream, mut input: TokenStream) -> TokenStream148 pub fn gen_to_string(_attr: TokenStream, mut input: TokenStream) -> TokenStream {
149 let (enum_name, enum_stream) = get_enum_and_stream(input.clone()).expect("Must use on enum");
150
151 let mut match_arms = TokenStream::new();
152 let enums_items = enum_stream.into_iter().filter_map(|tt| match tt {
153 TokenTree::Ident(id) => Some(id.to_string()),
154 _ => None,
155 });
156 for item in enums_items {
157 let arm: TokenStream = format!(
158 r#"{enum_name}::{item} => "{item}","#,
159 enum_name = enum_name,
160 item = item,
161 )
162 .parse()
163 .unwrap();
164 match_arms.extend(arm);
165 }
166
167 let implementation: TokenStream = format!(
168 r#"impl {enum_name} {{
169 pub fn to_string(&self) -> &'static str {{
170 match *self {{
171 {match_arms}
172 }}
173 }}
174 }}"#,
175 enum_name = enum_name,
176 match_arms = match_arms
177 )
178 .parse()
179 .unwrap();
180
181 // Emit input as is and add the to_string implementation
182 input.extend(implementation);
183 input
184 }
185
/// Parses the text of an integer literal token into an `i128`.
///
/// Handles the `0x`/`0o`/`0b` radix prefixes, `_` digit separators, and an optional integer
/// type suffix (`5u8`, `0xFF_u8`, `7i128`, ...), all of which are legal in enum discriminants.
/// `negative` negates the result; a negative discriminant arrives as a separate `-` token.
///
/// # Panics
///
/// Panics with `Invalid number <val>` if the digits cannot be parsed in the detected radix.
fn parse_to_i128(val: &str, negative: bool) -> i128 {
    const INT_SUFFIXES: [&str; 12] = [
        "usize", "isize", "u128", "i128", "u64", "i64", "u32", "i32", "u16", "i16", "u8", "i8",
    ];

    let (first_pos, base) = match val.get(0..2) {
        Some("0x") => (2, 16),
        Some("0o") => (2, 8),
        Some("0b") => (2, 2),
        _ => (0, 10),
    };

    // `u`/`i` are not digits in any supported radix, so a matching suffix is always a type
    // suffix and never part of the number itself.
    let body = &val[first_pos..];
    let body = INT_SUFFIXES
        .iter()
        .find_map(|suffix| body.strip_suffix(suffix))
        .unwrap_or(body);

    // Remove any helper `_` in the literal, then convert from the detected base.
    let digits: String = body.chars().filter(|c| *c != '_').collect();
    let magnitude =
        i128::from_str_radix(&digits, base).unwrap_or_else(|_| panic!("Invalid number {}", val));

    if negative {
        -magnitude
    } else {
        magnitude
    }
}
206
207 /// Generates a exclusive `END` const and `from_<repr>` function for an enum
208 ///
209 /// # Example
210 ///
211 /// ```
212 /// # #[macro_use] extern crate enum_utils;
213 /// #[enum_as(u8)]
214 /// enum MyEnum {
215 /// ValueZero,
216 /// ValueOne,
217 /// }
218 ///
219 /// # fn main() {
220 /// assert!(matches!(MyEnum::from_u8(1), Some(MyEnum::ValueOne)));
221 /// assert_eq!(MyEnum::END, 2);
222 /// # }
223 ///
224 /// ```
225 #[proc_macro_attribute]
enum_as(repr: TokenStream, input: TokenStream) -> TokenStream226 pub fn enum_as(repr: TokenStream, input: TokenStream) -> TokenStream {
227 let repr = repr.to_string();
228 let (enum_name, enum_stream) = get_enum_and_stream(input.clone()).expect("Must use on enum");
229
230 #[derive(Debug, Copy, Clone)]
231 enum WantState {
232 Determinant,
233 Punc,
234 NegativeVal,
235 Val,
236 CommentBlock,
237 }
238 let mut state = WantState::Determinant;
239 let mut skipped_ranges = vec![];
240 let mut start: Option<i128> = None;
241 let mut end: Option<i128> = None;
242 for tt in enum_stream {
243 use WantState::*;
244 state = match (tt, state) {
245 (TokenTree::Ident(_), Determinant) => Punc,
246 // If we are expecting a Determinant, but get a # instead, it must be a comment
247 (TokenTree::Punct(p), Determinant) if p.as_char() == '#' => CommentBlock,
248 (TokenTree::Punct(p), Val) if p.as_char() == '-' => NegativeVal,
249 (TokenTree::Group(_), CommentBlock) => Determinant,
250 (TokenTree::Punct(p), Punc) => match p.as_char() {
251 ',' => {
252 start = Some(start.unwrap_or(0));
253 end = Some(end.unwrap_or_else(|| start.unwrap()) + 1);
254 Determinant
255 }
256 '=' => Val,
257 other => panic!("Unexpected punctuation '{}'", other),
258 },
259 (TokenTree::Literal(l), Val | NegativeVal) => {
260 let val = parse_to_i128(&l.to_string(), matches!(state, NegativeVal));
261 start = Some(start.unwrap_or(val));
262 let expected = end.unwrap_or_else(|| start.unwrap());
263 match val.cmp(&expected) {
264 Ordering::Greater => {
265 skipped_ranges.push((expected, val));
266 end = Some(val);
267 }
268 Ordering::Less => panic!("Discriminants must increase in value"),
269 Ordering::Equal => (),
270 }
271 Punc
272 }
273 (tt, want) => {
274 panic!("Want {:?} but got {:?}", want, tt)
275 }
276 };
277 }
278
279 let skipped_ranges = if skipped_ranges.is_empty() {
280 "".to_string()
281 } else {
282 format!(
283 r#"for r in &{ranges:?} {{
284 if (r.0..r.1).contains(&val) {{
285 return None;
286 }}
287 }}"#,
288 ranges = skipped_ranges
289 )
290 };
291
292 // Ensure that there is at least one discriminant
293 let start = start.expect("Enum needs at least one discriminant");
294 let end = end.expect("Enum needs at least one discriminant");
295
296 // Ensure that END will fit into usize or u64
297 u64::try_from(end).unwrap_or_else(|_| panic!("Value after last discriminant must be unsigned"));
298
299 let implementation: TokenStream = format!(
300 r#"impl {enum_name} {{
301 pub const END: {end_type} = {end};
302
303 pub fn from_{repr}(val: {repr}) -> Option<Self> {{
304 if val < {start} {{
305 return None;
306 }}
307 if val > ((Self::END - 1) as {repr}) {{
308 return None;
309 }}
310 {skipped_ranges}
311 Some( unsafe {{ core::mem::transmute(val) }})
312 }}
313 }}
314
315 impl PartialEq for {enum_name} {{
316 fn eq(&self, other: &Self) -> bool {{
317 *self as {repr} == *other as {repr}
318 }}
319 }}
320
321 impl Eq for {enum_name} {{ }}
322
323 #[cfg(not(target_arch = "riscv32"))]
324 impl core::hash::Hash for {enum_name} {{
325 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {{
326 (*self as {repr}).hash(state);
327 }}
328 }}"#,
329 enum_name = enum_name,
330 start = start,
331 end = end,
332 end_type = if repr == "u64" { "u64" } else { "usize " },
333 repr = repr,
334 skipped_ranges = skipped_ranges,
335 )
336 .parse()
337 .unwrap();
338
339 // Attribute input with a repr, Clone, Clone, and allow(dead_code), then add the custom
340 // implementation. We allow dead code since the these enums are typically interface enums that
341 // define an API boundary.
342 let mut res: TokenStream = format!(
343 "#[repr({})]\n#[derive(Copy, Clone)]\n#[allow(dead_code)]",
344 repr
345 )
346 .parse()
347 .unwrap();
348 res.extend(input);
349 res.extend(implementation);
350 res
351 }
352
/// Returns the comma-separated parameter names of a function's parameter list, excluding the
/// `self` receiver, for use as call arguments.
///
/// For example, `&self, a: u8, mut b: Result<u8, u16>` yields `a,b`. Only the last identifier
/// before each parameter's `:` is used, so:
/// - `mut` bindings are dropped;
/// - a typed receiver (`self: &Self`) is skipped;
/// - commas nested in generic arguments (`<...>`) do not split a parameter.
fn get_param_list_without_self(params: TokenStream) -> String {
    let mut names: Vec<String> = Vec::new();
    // Last identifier seen in the current parameter's pattern (the part before `:`).
    let mut pattern_ident: Option<String> = None;
    // True until the `:` separating the current parameter's pattern from its type.
    let mut in_pattern = true;
    // Nesting depth of `<...>` in the current parameter's type.
    let mut angle_depth = 0usize;
    // True when the previous token was a `-` joined to the next punct (the start of `->`).
    let mut prev_was_arrow_dash = false;

    for tt in params {
        let mut is_arrow_dash = false;
        match tt {
            TokenTree::Ident(ident) if in_pattern => pattern_ident = Some(ident.to_string()),
            TokenTree::Punct(p) => match p.as_char() {
                ':' if in_pattern => {
                    in_pattern = false;
                    if let Some(name) = pattern_ident.take() {
                        if name != "self" {
                            names.push(name);
                        }
                    }
                }
                '<' if !in_pattern => angle_depth += 1,
                // The `>` of `->` (e.g. in `impl Fn() -> T`) does not close a generic list.
                '>' if !in_pattern && !prev_was_arrow_dash => {
                    angle_depth = angle_depth.saturating_sub(1)
                }
                '-' => is_arrow_dash = p.spacing() == Spacing::Joint,
                // Only a top-level comma separates parameters.
                ',' if angle_depth == 0 => {
                    in_pattern = true;
                    pattern_ident = None;
                }
                _ => {}
            },
            // Groups (tuple/array types, etc.) never contain top-level separators.
            _ => {}
        }
        prev_was_arrow_dash = is_arrow_dash;
    }

    names.join(",")
}
384
385 /// Replaces the function body with a single statement that pass all parameters thru to same
386 /// function name on the variable specified in the passthru_to macro
387 ///
388 /// # Example
389 ///
390 /// ```
391 /// # #[macro_use] extern crate enum_utils;
392 /// pub struct Inner(usize);
393 ///
394 /// impl Inner {
395 /// pub fn read_plus(&self, plus: usize) -> usize {
396 /// self.0 + plus
397 /// }
398 /// }
399 ///
400 /// pub struct Outer {
401 /// pub my_inner: Inner,
402 /// }
403 ///
404 /// impl Outer {
405 /// #[passthru_to(my_inner)]
406 /// pub fn read_plus(&self, plus: usize) -> usize {}
407 /// }
408 ///
409 /// ```
410 #[proc_macro_attribute]
passthru_to(passthru_var: TokenStream, input: TokenStream) -> TokenStream411 pub fn passthru_to(passthru_var: TokenStream, input: TokenStream) -> TokenStream {
412 #[derive(Debug, Copy, Clone)]
413 enum WantState {
414 FunctionKeyword,
415 FunctionName,
416 Parameters,
417 Body,
418 End,
419 }
420 let mut state = WantState::FunctionKeyword;
421 let mut name = None;
422 let mut params = None;
423 let mut body_num = None;
424 for (i, tt) in input.clone().into_iter().enumerate() {
425 use WantState::*;
426 state = match (tt, state) {
427 (TokenTree::Ident(ident), FunctionKeyword) if ident.to_string() == "fn" => FunctionName,
428 (_, FunctionKeyword) => FunctionKeyword,
429 (TokenTree::Ident(ident), FunctionName) => {
430 name = Some(ident.to_string());
431 Parameters
432 }
433 (TokenTree::Group(group), Parameters) => {
434 params = Some(get_param_list_without_self(group.stream()));
435 Body
436 }
437 (TokenTree::Group(group), Body) if group.delimiter() == Delimiter::Brace => {
438 body_num = Some(i);
439 End
440 }
441 (_, Body) => Body,
442 (tt, want) => {
443 panic!("Want {:?} but got {:?}", want, tt)
444 }
445 };
446 }
447
448 let implementation: TokenStream = format!(
449 r#"{{ self.{passthru_var}.{name}({params}) }}"#,
450 passthru_var = passthru_var,
451 name = name.expect("Cannot find function name"),
452 params = params.expect("Could find parameters"),
453 )
454 .parse()
455 .unwrap();
456
457 // Take everything up to body, then replace body with the passthru implementation
458 input
459 .into_iter()
460 .take(body_num.expect("Cannot find body"))
461 .chain(implementation)
462 .collect()
463 }
464