1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::fmt::Write;
6
7 use quote::quote;
8
9 /// A helper derive proc macro to flatten multiple subcommand enums into one
10 /// Note that it is unable to check for duplicate commands and they will be
11 /// tried in order of declaration
12 #[proc_macro_derive(FlattenSubcommand)]
flatten_subcommand(input: proc_macro::TokenStream) -> proc_macro::TokenStream13 pub fn flatten_subcommand(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
14 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
15 let de = match ast.data {
16 syn::Data::Enum(v) => v,
17 _ => unreachable!(),
18 };
19 let name = &ast.ident;
20
21 // An enum variant like `<name>(<ty>)`
22 struct SubCommandVariant<'a> {
23 name: &'a syn::Ident,
24 ty: &'a syn::Type,
25 }
26
27 let variants: Vec<SubCommandVariant<'_>> = de
28 .variants
29 .iter()
30 .map(|variant| {
31 let name = &variant.ident;
32 let ty = match &variant.fields {
33 syn::Fields::Unnamed(field) => {
34 if field.unnamed.len() != 1 {
35 unreachable!()
36 }
37
38 &field.unnamed.first().unwrap().ty
39 }
40 _ => unreachable!(),
41 };
42 SubCommandVariant { name, ty }
43 })
44 .collect();
45
46 let variant_ty = variants.iter().map(|x| x.ty).collect::<Vec<_>>();
47 let variant_names = variants.iter().map(|x| x.name).collect::<Vec<_>>();
48
49 (quote! {
50 impl argh::FromArgs for #name {
51 fn from_args(command_name: &[&str], args: &[&str])
52 -> std::result::Result<Self, argh::EarlyExit>
53 {
54 let subcommand_name = if let Some(subcommand_name) = command_name.last() {
55 *subcommand_name
56 } else {
57 return Err(argh::EarlyExit::from("no subcommand name".to_owned()));
58 };
59
60 #(
61 if <#variant_ty as argh::SubCommands>::COMMANDS
62 .iter()
63 .find(|ci| ci.name.eq(subcommand_name))
64 .is_some()
65 {
66 return <#variant_ty as argh::FromArgs>::from_args(command_name, args)
67 .map(|v| Self::#variant_names(v));
68 }
69 )*
70
71 Err(argh::EarlyExit::from("no subcommand matched".to_owned()))
72 }
73
74 fn redact_arg_values(command_name: &[&str], args: &[&str]) -> std::result::Result<Vec<String>, argh::EarlyExit> {
75 let subcommand_name = if let Some(subcommand_name) = command_name.last() {
76 *subcommand_name
77 } else {
78 return Err(argh::EarlyExit::from("no subcommand name".to_owned()));
79 };
80
81 #(
82 if <#variant_ty as argh::SubCommands>::COMMANDS
83 .iter()
84 .find(|ci| ci.name.eq(subcommand_name))
85 .is_some()
86 {
87 return <#variant_ty as argh::FromArgs>::redact_arg_values(
88 command_name,
89 args,
90 );
91 }
92
93 )*
94
95 Err(argh::EarlyExit::from("no subcommand matched".to_owned()))
96 }
97 }
98
99 impl argh::SubCommands for #name {
100 const COMMANDS: &'static [&'static argh::CommandInfo] = {
101 const TOTAL_LEN: usize = #(<#variant_ty as argh::SubCommands>::COMMANDS.len())+*;
102 const COMMANDS: [&'static argh::CommandInfo; TOTAL_LEN] = {
103 let slices = &[#(<#variant_ty as argh::SubCommands>::COMMANDS,)*];
104 // Its not possible for slices[0][0] to be invalid
105 let mut output = [slices[0][0]; TOTAL_LEN];
106
107 let mut output_index = 0;
108 let mut which_slice = 0;
109 while which_slice < slices.len() {
110 let slice = &slices[which_slice];
111 let mut index_in_slice = 0;
112 while index_in_slice < slice.len() {
113 output[output_index] = slice[index_in_slice];
114 output_index += 1;
115 index_in_slice += 1;
116 }
117 which_slice += 1;
118 }
119 output
120 };
121 &COMMANDS
122 };
123 }
124 })
125 .into()
126 }
127
128 /// A helper proc macro to pad strings so that argh would break them at intended points
129 #[proc_macro_attribute]
pad_description_for_argh( _attr: proc_macro::TokenStream, item: proc_macro::TokenStream, ) -> proc_macro::TokenStream130 pub fn pad_description_for_argh(
131 _attr: proc_macro::TokenStream,
132 item: proc_macro::TokenStream,
133 ) -> proc_macro::TokenStream {
134 let mut item = syn::parse_macro_input!(item as syn::Item);
135 if let syn::Item::Struct(s) = &mut item {
136 if let syn::Fields::Named(fields) = &mut s.fields {
137 for f in fields.named.iter_mut() {
138 for a in f.attrs.iter_mut() {
139 if a.path()
140 .get_ident()
141 .map(|i| i.to_string())
142 .unwrap_or_default()
143 == *"doc"
144 {
145 if let syn::Meta::NameValue(syn::MetaNameValue {
146 value:
147 syn::Expr::Lit(syn::ExprLit {
148 lit: syn::Lit::Str(s),
149 ..
150 }),
151 ..
152 }) = &a.meta
153 {
154 let doc = s.value().lines().fold(String::new(), |mut output, s| {
155 let _ = write!(output, "{: <61}", s);
156 output
157 });
158 *a = syn::parse_quote! { #[doc= #doc] };
159 }
160 }
161 }
162 }
163 } else {
164 unreachable!()
165 }
166 } else {
167 unreachable!()
168 }
169 quote! {
170 #item
171 }
172 .into()
173 }
174