xref: /aosp_15_r20/external/crosvm/argh_helpers/src/lib.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::fmt::Write;
6 
7 use quote::quote;
8 
9 /// A helper derive proc macro to flatten multiple subcommand enums into one
10 /// Note that it is unable to check for duplicate commands and they will be
11 /// tried in order of declaration
12 #[proc_macro_derive(FlattenSubcommand)]
flatten_subcommand(input: proc_macro::TokenStream) -> proc_macro::TokenStream13 pub fn flatten_subcommand(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
14     let ast = syn::parse_macro_input!(input as syn::DeriveInput);
15     let de = match ast.data {
16         syn::Data::Enum(v) => v,
17         _ => unreachable!(),
18     };
19     let name = &ast.ident;
20 
21     // An enum variant like `<name>(<ty>)`
22     struct SubCommandVariant<'a> {
23         name: &'a syn::Ident,
24         ty: &'a syn::Type,
25     }
26 
27     let variants: Vec<SubCommandVariant<'_>> = de
28         .variants
29         .iter()
30         .map(|variant| {
31             let name = &variant.ident;
32             let ty = match &variant.fields {
33                 syn::Fields::Unnamed(field) => {
34                     if field.unnamed.len() != 1 {
35                         unreachable!()
36                     }
37 
38                     &field.unnamed.first().unwrap().ty
39                 }
40                 _ => unreachable!(),
41             };
42             SubCommandVariant { name, ty }
43         })
44         .collect();
45 
46     let variant_ty = variants.iter().map(|x| x.ty).collect::<Vec<_>>();
47     let variant_names = variants.iter().map(|x| x.name).collect::<Vec<_>>();
48 
49     (quote! {
50         impl argh::FromArgs for #name {
51             fn from_args(command_name: &[&str], args: &[&str])
52                 -> std::result::Result<Self, argh::EarlyExit>
53             {
54                 let subcommand_name = if let Some(subcommand_name) = command_name.last() {
55                     *subcommand_name
56                 } else {
57                     return Err(argh::EarlyExit::from("no subcommand name".to_owned()));
58                 };
59 
60                 #(
61                     if <#variant_ty as argh::SubCommands>::COMMANDS
62                     .iter()
63                     .find(|ci| ci.name.eq(subcommand_name))
64                     .is_some()
65                     {
66                         return <#variant_ty as argh::FromArgs>::from_args(command_name, args)
67                             .map(|v| Self::#variant_names(v));
68                     }
69                 )*
70 
71                 Err(argh::EarlyExit::from("no subcommand matched".to_owned()))
72             }
73 
74             fn redact_arg_values(command_name: &[&str], args: &[&str]) -> std::result::Result<Vec<String>, argh::EarlyExit> {
75                 let subcommand_name = if let Some(subcommand_name) = command_name.last() {
76                     *subcommand_name
77                 } else {
78                     return Err(argh::EarlyExit::from("no subcommand name".to_owned()));
79                 };
80 
81                 #(
82                     if <#variant_ty as argh::SubCommands>::COMMANDS
83                     .iter()
84                     .find(|ci| ci.name.eq(subcommand_name))
85                     .is_some()
86                     {
87                         return <#variant_ty as argh::FromArgs>::redact_arg_values(
88                             command_name,
89                             args,
90                         );
91                     }
92 
93                 )*
94 
95                 Err(argh::EarlyExit::from("no subcommand matched".to_owned()))
96             }
97         }
98 
99         impl argh::SubCommands for #name {
100             const COMMANDS: &'static [&'static argh::CommandInfo] = {
101                 const TOTAL_LEN: usize = #(<#variant_ty as argh::SubCommands>::COMMANDS.len())+*;
102                 const COMMANDS: [&'static argh::CommandInfo; TOTAL_LEN] = {
103                     let slices = &[#(<#variant_ty as argh::SubCommands>::COMMANDS,)*];
104                     // Its not possible for slices[0][0] to be invalid
105                     let mut output = [slices[0][0]; TOTAL_LEN];
106 
107                     let mut output_index = 0;
108                     let mut which_slice = 0;
109                     while which_slice < slices.len() {
110                         let slice = &slices[which_slice];
111                         let mut index_in_slice = 0;
112                         while index_in_slice < slice.len() {
113                             output[output_index] = slice[index_in_slice];
114                             output_index += 1;
115                             index_in_slice += 1;
116                         }
117                         which_slice += 1;
118                     }
119                     output
120                 };
121                 &COMMANDS
122             };
123         }
124     })
125     .into()
126 }
127 
128 /// A helper proc macro to pad strings so that argh would break them at intended points
129 #[proc_macro_attribute]
pad_description_for_argh( _attr: proc_macro::TokenStream, item: proc_macro::TokenStream, ) -> proc_macro::TokenStream130 pub fn pad_description_for_argh(
131     _attr: proc_macro::TokenStream,
132     item: proc_macro::TokenStream,
133 ) -> proc_macro::TokenStream {
134     let mut item = syn::parse_macro_input!(item as syn::Item);
135     if let syn::Item::Struct(s) = &mut item {
136         if let syn::Fields::Named(fields) = &mut s.fields {
137             for f in fields.named.iter_mut() {
138                 for a in f.attrs.iter_mut() {
139                     if a.path()
140                         .get_ident()
141                         .map(|i| i.to_string())
142                         .unwrap_or_default()
143                         == *"doc"
144                     {
145                         if let syn::Meta::NameValue(syn::MetaNameValue {
146                             value:
147                                 syn::Expr::Lit(syn::ExprLit {
148                                     lit: syn::Lit::Str(s),
149                                     ..
150                                 }),
151                             ..
152                         }) = &a.meta
153                         {
154                             let doc = s.value().lines().fold(String::new(), |mut output, s| {
155                                 let _ = write!(output, "{: <61}", s);
156                                 output
157                             });
158                             *a = syn::parse_quote! { #[doc= #doc] };
159                         }
160                     }
161                 }
162             }
163         } else {
164             unreachable!()
165         }
166     } else {
167         unreachable!()
168     }
169     quote! {
170         #item
171     }
172     .into()
173 }
174