1 #![recursion_limit = "256"]
2 // Copyright (c) 2020 Google LLC All rights reserved.
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 /// Implementation of the `FromArgs` and `argh(...)` derive attributes.
7 ///
8 /// For more thorough documentation, see the `argh` crate itself.
9 extern crate proc_macro;
10 
11 use {
12     crate::{
13         errors::Errors,
14         parse_attrs::{check_long_name, FieldAttrs, FieldKind, TypeAttrs},
15     },
16     proc_macro2::{Span, TokenStream},
17     quote::{quote, quote_spanned, ToTokens},
18     std::{collections::HashMap, str::FromStr},
19     syn::{spanned::Spanned, GenericArgument, LitStr, PathArguments, Type},
20 };
21 
22 mod args_info;
23 mod errors;
24 mod help;
25 mod parse_attrs;
26 
27 /// Entrypoint for `#[derive(FromArgs)]`.
28 #[proc_macro_derive(FromArgs, attributes(argh))]
argh_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream29 pub fn argh_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
30     let ast = syn::parse_macro_input!(input as syn::DeriveInput);
31     let gen = impl_from_args(&ast);
32     gen.into()
33 }
34 
35 /// Entrypoint for `#[derive(ArgsInfo)]`.
36 #[proc_macro_derive(ArgsInfo, attributes(argh))]
args_info_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream37 pub fn args_info_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
38     let ast = syn::parse_macro_input!(input as syn::DeriveInput);
39     let gen = args_info::impl_args_info(&ast);
40     gen.into()
41 }
42 
43 /// Transform the input into a token stream containing any generated implementations,
44 /// as well as all errors that occurred.
impl_from_args(input: &syn::DeriveInput) -> TokenStream45 fn impl_from_args(input: &syn::DeriveInput) -> TokenStream {
46     let errors = &Errors::default();
47     let type_attrs = &TypeAttrs::parse(errors, input);
48     let mut output_tokens = match &input.data {
49         syn::Data::Struct(ds) => {
50             impl_from_args_struct(errors, &input.ident, type_attrs, &input.generics, ds)
51         }
52         syn::Data::Enum(de) => {
53             impl_from_args_enum(errors, &input.ident, type_attrs, &input.generics, de)
54         }
55         syn::Data::Union(_) => {
56             errors.err(input, "`#[derive(FromArgs)]` cannot be applied to unions");
57             TokenStream::new()
58         }
59     };
60     errors.to_tokens(&mut output_tokens);
61     output_tokens
62 }
63 
64 /// The kind of optionality a parameter has.
65 enum Optionality {
66     None,
67     Defaulted(TokenStream),
68     Optional,
69     Repeating,
70 }
71 
72 impl PartialEq<Optionality> for Optionality {
eq(&self, other: &Optionality) -> bool73     fn eq(&self, other: &Optionality) -> bool {
74         use Optionality::*;
75         // NB: (Defaulted, Defaulted) can't contain the same token streams
76         matches!((self, other), (Optional, Optional) | (Repeating, Repeating))
77     }
78 }
79 
80 impl Optionality {
81     /// Whether or not this is `Optionality::None`
is_required(&self) -> bool82     fn is_required(&self) -> bool {
83         matches!(self, Optionality::None)
84     }
85 }
86 
87 /// A field of a `#![derive(FromArgs)]` struct with attributes and some other
88 /// notable metadata appended.
89 struct StructField<'a> {
90     /// The original parsed field
91     field: &'a syn::Field,
92     /// The parsed attributes of the field
93     attrs: FieldAttrs,
94     /// The field name. This is contained optionally inside `field`,
95     /// but is duplicated non-optionally here to indicate that all field that
96     /// have reached this point must have a field name, and it no longer
97     /// needs to be unwrapped.
98     name: &'a syn::Ident,
99     /// Similar to `name` above, this is contained optionally inside `FieldAttrs`,
100     /// but here is fully present to indicate that we only have to consider fields
101     /// with a valid `kind` at this point.
102     kind: FieldKind,
103     // If `field.ty` is `Vec<T>` or `Option<T>`, this is `T`, otherwise it's `&field.ty`.
104     // This is used to enable consistent parsing code between optional and non-optional
105     // keyed and subcommand fields.
106     ty_without_wrapper: &'a syn::Type,
107     // Whether the field represents an optional value, such as an `Option` subcommand field
108     // or an `Option` or `Vec` keyed argument, or if it has a `default`.
109     optionality: Optionality,
110     // The `--`-prefixed name of the option, if one exists.
111     long_name: Option<String>,
112 }
113 
114 impl<'a> StructField<'a> {
115     /// Attempts to parse a field of a `#[derive(FromArgs)]` struct, pulling out the
116     /// fields required for code generation.
new(errors: &Errors, field: &'a syn::Field, attrs: FieldAttrs) -> Option<Self>117     fn new(errors: &Errors, field: &'a syn::Field, attrs: FieldAttrs) -> Option<Self> {
118         let name = field.ident.as_ref().expect("missing ident for named field");
119 
120         // Ensure that one "kind" is present (switch, option, subcommand, positional)
121         let kind = if let Some(field_type) = &attrs.field_type {
122             field_type.kind
123         } else {
124             errors.err(
125                 field,
126                 concat!(
127                     "Missing `argh` field kind attribute.\n",
128                     "Expected one of: `switch`, `option`, `remaining`, `subcommand`, `positional`",
129                 ),
130             );
131             return None;
132         };
133 
134         // Parse out whether a field is optional (`Option` or `Vec`).
135         let optionality;
136         let ty_without_wrapper;
137         match kind {
138             FieldKind::Switch => {
139                 if !ty_expect_switch(errors, &field.ty) {
140                     return None;
141                 }
142                 optionality = Optionality::Optional;
143                 ty_without_wrapper = &field.ty;
144             }
145             FieldKind::Option | FieldKind::Positional => {
146                 if let Some(default) = &attrs.default {
147                     let tokens = match TokenStream::from_str(&default.value()) {
148                         Ok(tokens) => tokens,
149                         Err(_) => {
150                             errors.err(&default, "Invalid tokens: unable to lex `default` value");
151                             return None;
152                         }
153                     };
154                     // Set the span of the generated tokens to the string literal
155                     let tokens: TokenStream = tokens
156                         .into_iter()
157                         .map(|mut tree| {
158                             tree.set_span(default.span());
159                             tree
160                         })
161                         .collect();
162                     optionality = Optionality::Defaulted(tokens);
163                     ty_without_wrapper = &field.ty;
164                 } else {
165                     let mut inner = None;
166                     optionality = if let Some(x) = ty_inner(&["Option"], &field.ty) {
167                         inner = Some(x);
168                         Optionality::Optional
169                     } else if let Some(x) = ty_inner(&["Vec"], &field.ty) {
170                         inner = Some(x);
171                         Optionality::Repeating
172                     } else {
173                         Optionality::None
174                     };
175                     ty_without_wrapper = inner.unwrap_or(&field.ty);
176                 }
177             }
178             FieldKind::SubCommand => {
179                 let inner = ty_inner(&["Option"], &field.ty);
180                 optionality =
181                     if inner.is_some() { Optionality::Optional } else { Optionality::None };
182                 ty_without_wrapper = inner.unwrap_or(&field.ty);
183             }
184         }
185 
186         // Determine the "long" name of options and switches.
187         // Defaults to the kebab-case'd field name if `#[argh(long = "...")]` is omitted.
188         let long_name = match kind {
189             FieldKind::Switch | FieldKind::Option => {
190                 let long_name = attrs.long.as_ref().map(syn::LitStr::value).unwrap_or_else(|| {
191                     let kebab_name = to_kebab_case(&name.to_string());
192                     check_long_name(errors, name, &kebab_name);
193                     kebab_name
194                 });
195                 if long_name == "help" {
196                     errors.err(field, "Custom `--help` flags are not supported.");
197                 }
198                 let long_name = format!("--{}", long_name);
199                 Some(long_name)
200             }
201             FieldKind::SubCommand | FieldKind::Positional => None,
202         };
203 
204         Some(StructField { field, attrs, kind, optionality, ty_without_wrapper, name, long_name })
205     }
206 
positional_arg_name(&self) -> String207     pub(crate) fn positional_arg_name(&self) -> String {
208         self.attrs
209             .arg_name
210             .as_ref()
211             .map(LitStr::value)
212             .unwrap_or_else(|| self.name.to_string().trim_matches('_').to_owned())
213     }
214 }
215 
/// Converts a `snake_case` identifier into `kebab-case`.
///
/// Each underscore-separated segment is joined with `-`. Leading, trailing, and repeated
/// underscores produce empty segments, which are dropped. Letter case is left untouched.
fn to_kebab_case(s: &str) -> String {
    s.split('_').filter(|segment| !segment.is_empty()).collect::<Vec<&str>>().join("-")
}
227 
#[test]
fn test_kebabs() {
    // Pairs of (snake_case input, expected kebab-case output).
    let cases = [
        ("", ""),
        ("_", ""),
        ("foo", "foo"),
        ("__foo_", "foo"),
        ("foo_bar", "foo-bar"),
        ("foo__Bar", "foo-Bar"),
        ("foo_bar__baz_", "foo-bar-baz"),
    ];
    for &(input, expected) in &cases {
        assert_eq!(to_kebab_case(input), expected, "to_kebab_case({:?})", input);
    }
}
243 
244 /// Implements `FromArgs` and `TopLevelCommand` or `SubCommand` for a `#[derive(FromArgs)]` struct.
impl_from_args_struct( errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs, generic_args: &syn::Generics, ds: &syn::DataStruct, ) -> TokenStream245 fn impl_from_args_struct(
246     errors: &Errors,
247     name: &syn::Ident,
248     type_attrs: &TypeAttrs,
249     generic_args: &syn::Generics,
250     ds: &syn::DataStruct,
251 ) -> TokenStream {
252     let fields = match &ds.fields {
253         syn::Fields::Named(fields) => fields,
254         syn::Fields::Unnamed(_) => {
255             errors.err(
256                 &ds.struct_token,
257                 "`#![derive(FromArgs)]` is not currently supported on tuple structs",
258             );
259             return TokenStream::new();
260         }
261         syn::Fields::Unit => {
262             errors.err(&ds.struct_token, "#![derive(FromArgs)]` cannot be applied to unit structs");
263             return TokenStream::new();
264         }
265     };
266 
267     let fields: Vec<_> = fields
268         .named
269         .iter()
270         .filter_map(|field| {
271             let attrs = FieldAttrs::parse(errors, field);
272             StructField::new(errors, field, attrs)
273         })
274         .collect();
275 
276     ensure_unique_names(errors, &fields);
277     ensure_only_last_positional_is_optional(errors, &fields);
278 
279     let impl_span = Span::call_site();
280 
281     let from_args_method = impl_from_args_struct_from_args(errors, type_attrs, &fields);
282 
283     let redact_arg_values_method =
284         impl_from_args_struct_redact_arg_values(errors, type_attrs, &fields);
285 
286     let top_or_sub_cmd_impl = top_or_sub_cmd_impl(errors, name, type_attrs, generic_args);
287 
288     let (impl_generics, ty_generics, where_clause) = generic_args.split_for_impl();
289     let trait_impl = quote_spanned! { impl_span =>
290         #[automatically_derived]
291         impl #impl_generics argh::FromArgs for #name #ty_generics #where_clause {
292             #from_args_method
293 
294             #redact_arg_values_method
295         }
296 
297         #top_or_sub_cmd_impl
298     };
299 
300     trait_impl
301 }
302 
impl_from_args_struct_from_args<'a>( errors: &Errors, type_attrs: &TypeAttrs, fields: &'a [StructField<'a>], ) -> TokenStream303 fn impl_from_args_struct_from_args<'a>(
304     errors: &Errors,
305     type_attrs: &TypeAttrs,
306     fields: &'a [StructField<'a>],
307 ) -> TokenStream {
308     let init_fields = declare_local_storage_for_from_args_fields(fields);
309     let unwrap_fields = unwrap_from_args_fields(fields);
310     let positional_fields: Vec<&StructField<'_>> =
311         fields.iter().filter(|field| field.kind == FieldKind::Positional).collect();
312     let positional_field_idents = positional_fields.iter().map(|field| &field.field.ident);
313     let positional_field_names = positional_fields.iter().map(|field| field.name.to_string());
314     let last_positional_is_repeating = positional_fields
315         .last()
316         .map(|field| field.optionality == Optionality::Repeating)
317         .unwrap_or(false);
318     let last_positional_is_greedy = positional_fields
319         .last()
320         .map(|field| field.kind == FieldKind::Positional && field.attrs.greedy.is_some())
321         .unwrap_or(false);
322 
323     let flag_output_table = fields.iter().filter_map(|field| {
324         let field_name = &field.field.ident;
325         match field.kind {
326             FieldKind::Option => Some(quote! { argh::ParseStructOption::Value(&mut #field_name) }),
327             FieldKind::Switch => Some(quote! { argh::ParseStructOption::Flag(&mut #field_name) }),
328             FieldKind::SubCommand | FieldKind::Positional => None,
329         }
330     });
331 
332     let flag_str_to_output_table_map = flag_str_to_output_table_map_entries(fields);
333 
334     let mut subcommands_iter =
335         fields.iter().filter(|field| field.kind == FieldKind::SubCommand).fuse();
336 
337     let subcommand: Option<&StructField<'_>> = subcommands_iter.next();
338     for dup_subcommand in subcommands_iter {
339         errors.duplicate_attrs("subcommand", subcommand.unwrap().field, dup_subcommand.field);
340     }
341 
342     let impl_span = Span::call_site();
343 
344     let missing_requirements_ident = syn::Ident::new("__missing_requirements", impl_span);
345 
346     let append_missing_requirements =
347         append_missing_requirements(&missing_requirements_ident, fields);
348 
349     let parse_subcommands = if let Some(subcommand) = subcommand {
350         let name = subcommand.name;
351         let ty = subcommand.ty_without_wrapper;
352         quote_spanned! { impl_span =>
353             Some(argh::ParseStructSubCommand {
354                 subcommands: <#ty as argh::SubCommands>::COMMANDS,
355                 dynamic_subcommands: &<#ty as argh::SubCommands>::dynamic_commands(),
356                 parse_func: &mut |__command, __remaining_args| {
357                     #name = Some(<#ty as argh::FromArgs>::from_args(__command, __remaining_args)?);
358                     Ok(())
359                 },
360             })
361         }
362     } else {
363         quote_spanned! { impl_span => None }
364     };
365 
366     // Identifier referring to a value containing the name of the current command as an `&[&str]`.
367     let cmd_name_str_array_ident = syn::Ident::new("__cmd_name", impl_span);
368     let help = help::help(errors, cmd_name_str_array_ident, type_attrs, fields, subcommand);
369 
370     let method_impl = quote_spanned! { impl_span =>
371         fn from_args(__cmd_name: &[&str], __args: &[&str])
372             -> std::result::Result<Self, argh::EarlyExit>
373         {
374             #![allow(clippy::unwrap_in_result)]
375 
376             #( #init_fields )*
377 
378             argh::parse_struct_args(
379                 __cmd_name,
380                 __args,
381                 argh::ParseStructOptions {
382                     arg_to_slot: &[ #( #flag_str_to_output_table_map ,)* ],
383                     slots: &mut [ #( #flag_output_table, )* ],
384                 },
385                 argh::ParseStructPositionals {
386                     positionals: &mut [
387                         #(
388                             argh::ParseStructPositional {
389                                 name: #positional_field_names,
390                                 slot: &mut #positional_field_idents as &mut argh::ParseValueSlot,
391                             },
392                         )*
393                     ],
394                     last_is_repeating: #last_positional_is_repeating,
395                     last_is_greedy: #last_positional_is_greedy,
396                 },
397                 #parse_subcommands,
398                 &|| #help,
399             )?;
400 
401             let mut #missing_requirements_ident = argh::MissingRequirements::default();
402             #(
403                 #append_missing_requirements
404             )*
405             #missing_requirements_ident.err_on_any()?;
406 
407             Ok(Self {
408                 #( #unwrap_fields, )*
409             })
410         }
411     };
412 
413     method_impl
414 }
415 
impl_from_args_struct_redact_arg_values<'a>( errors: &Errors, type_attrs: &TypeAttrs, fields: &'a [StructField<'a>], ) -> TokenStream416 fn impl_from_args_struct_redact_arg_values<'a>(
417     errors: &Errors,
418     type_attrs: &TypeAttrs,
419     fields: &'a [StructField<'a>],
420 ) -> TokenStream {
421     let init_fields = declare_local_storage_for_redacted_fields(fields);
422     let unwrap_fields = unwrap_redacted_fields(fields);
423 
424     let positional_fields: Vec<&StructField<'_>> =
425         fields.iter().filter(|field| field.kind == FieldKind::Positional).collect();
426     let positional_field_idents = positional_fields.iter().map(|field| &field.field.ident);
427     let positional_field_names = positional_fields.iter().map(|field| field.name.to_string());
428     let last_positional_is_repeating = positional_fields
429         .last()
430         .map(|field| field.optionality == Optionality::Repeating)
431         .unwrap_or(false);
432     let last_positional_is_greedy = positional_fields
433         .last()
434         .map(|field| field.kind == FieldKind::Positional && field.attrs.greedy.is_some())
435         .unwrap_or(false);
436 
437     let flag_output_table = fields.iter().filter_map(|field| {
438         let field_name = &field.field.ident;
439         match field.kind {
440             FieldKind::Option => Some(quote! { argh::ParseStructOption::Value(&mut #field_name) }),
441             FieldKind::Switch => Some(quote! { argh::ParseStructOption::Flag(&mut #field_name) }),
442             FieldKind::SubCommand | FieldKind::Positional => None,
443         }
444     });
445 
446     let flag_str_to_output_table_map = flag_str_to_output_table_map_entries(fields);
447 
448     let mut subcommands_iter =
449         fields.iter().filter(|field| field.kind == FieldKind::SubCommand).fuse();
450 
451     let subcommand: Option<&StructField<'_>> = subcommands_iter.next();
452     for dup_subcommand in subcommands_iter {
453         errors.duplicate_attrs("subcommand", subcommand.unwrap().field, dup_subcommand.field);
454     }
455 
456     let impl_span = Span::call_site();
457 
458     let missing_requirements_ident = syn::Ident::new("__missing_requirements", impl_span);
459 
460     let append_missing_requirements =
461         append_missing_requirements(&missing_requirements_ident, fields);
462 
463     let redact_subcommands = if let Some(subcommand) = subcommand {
464         let name = subcommand.name;
465         let ty = subcommand.ty_without_wrapper;
466         quote_spanned! { impl_span =>
467             Some(argh::ParseStructSubCommand {
468                 subcommands: <#ty as argh::SubCommands>::COMMANDS,
469                 dynamic_subcommands: &<#ty as argh::SubCommands>::dynamic_commands(),
470                 parse_func: &mut |__command, __remaining_args| {
471                     #name = Some(<#ty as argh::FromArgs>::redact_arg_values(__command, __remaining_args)?);
472                     Ok(())
473                 },
474             })
475         }
476     } else {
477         quote_spanned! { impl_span => None }
478     };
479 
480     let unwrap_cmd_name_err_string = if type_attrs.is_subcommand.is_none() {
481         quote! { "no command name" }
482     } else {
483         quote! { "no subcommand name" }
484     };
485 
486     // Identifier referring to a value containing the name of the current command as an `&[&str]`.
487     let cmd_name_str_array_ident = syn::Ident::new("__cmd_name", impl_span);
488     let help = help::help(errors, cmd_name_str_array_ident, type_attrs, fields, subcommand);
489 
490     let method_impl = quote_spanned! { impl_span =>
491         fn redact_arg_values(__cmd_name: &[&str], __args: &[&str]) -> std::result::Result<Vec<String>, argh::EarlyExit> {
492             #( #init_fields )*
493 
494             argh::parse_struct_args(
495                 __cmd_name,
496                 __args,
497                 argh::ParseStructOptions {
498                     arg_to_slot: &[ #( #flag_str_to_output_table_map ,)* ],
499                     slots: &mut [ #( #flag_output_table, )* ],
500                 },
501                 argh::ParseStructPositionals {
502                     positionals: &mut [
503                         #(
504                             argh::ParseStructPositional {
505                                 name: #positional_field_names,
506                                 slot: &mut #positional_field_idents as &mut argh::ParseValueSlot,
507                             },
508                         )*
509                     ],
510                     last_is_repeating: #last_positional_is_repeating,
511                     last_is_greedy: #last_positional_is_greedy,
512                 },
513                 #redact_subcommands,
514                 &|| #help,
515             )?;
516 
517             let mut #missing_requirements_ident = argh::MissingRequirements::default();
518             #(
519                 #append_missing_requirements
520             )*
521             #missing_requirements_ident.err_on_any()?;
522 
523             let mut __redacted = vec![
524                 if let Some(cmd_name) = __cmd_name.last() {
525                     (*cmd_name).to_owned()
526                 } else {
527                     return Err(argh::EarlyExit::from(#unwrap_cmd_name_err_string.to_owned()));
528                 }
529             ];
530 
531             #( #unwrap_fields )*
532 
533             Ok(__redacted)
534         }
535     };
536 
537     method_impl
538 }
539 
540 /// Ensures that only the last positional arg is non-required.
ensure_only_last_positional_is_optional(errors: &Errors, fields: &[StructField<'_>])541 fn ensure_only_last_positional_is_optional(errors: &Errors, fields: &[StructField<'_>]) {
542     let mut first_non_required_span = None;
543     for field in fields {
544         if field.kind == FieldKind::Positional {
545             if let Some(first) = first_non_required_span {
546                 errors.err_span(
547                     first,
548                     "Only the last positional argument may be `Option`, `Vec`, or defaulted.",
549                 );
550                 errors.err(&field.field, "Later positional argument declared here.");
551                 return;
552             }
553             if !field.optionality.is_required() {
554                 first_non_required_span = Some(field.field.span());
555             }
556         }
557     }
558 }
559 
560 /// Ensures that only one short or long name is used.
ensure_unique_names(errors: &Errors, fields: &[StructField<'_>])561 fn ensure_unique_names(errors: &Errors, fields: &[StructField<'_>]) {
562     let mut seen_short_names = HashMap::new();
563     let mut seen_long_names = HashMap::new();
564 
565     for field in fields {
566         if let Some(short_name) = &field.attrs.short {
567             let short_name = short_name.value();
568             if let Some(first_use_field) = seen_short_names.get(&short_name) {
569                 errors.err_span_tokens(
570                     first_use_field,
571                     &format!("The short name of \"-{}\" was already used here.", short_name),
572                 );
573                 errors.err_span_tokens(field.field, "Later usage here.");
574             }
575 
576             seen_short_names.insert(short_name, &field.field);
577         }
578 
579         if let Some(long_name) = &field.long_name {
580             if let Some(first_use_field) = seen_long_names.get(&long_name) {
581                 errors.err_span_tokens(
582                     *first_use_field,
583                     &format!("The long name of \"{}\" was already used here.", long_name),
584                 );
585                 errors.err_span_tokens(field.field, "Later usage here.");
586             }
587 
588             seen_long_names.insert(long_name, field.field);
589         }
590     }
591 }
592 
593 /// Implement `argh::TopLevelCommand` or `argh::SubCommand` as appropriate.
top_or_sub_cmd_impl( errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs, generic_args: &syn::Generics, ) -> TokenStream594 fn top_or_sub_cmd_impl(
595     errors: &Errors,
596     name: &syn::Ident,
597     type_attrs: &TypeAttrs,
598     generic_args: &syn::Generics,
599 ) -> TokenStream {
600     let description =
601         help::require_description(errors, name.span(), &type_attrs.description, "type");
602     let (impl_generics, ty_generics, where_clause) = generic_args.split_for_impl();
603     if type_attrs.is_subcommand.is_none() {
604         // Not a subcommand
605         quote! {
606             #[automatically_derived]
607             impl #impl_generics argh::TopLevelCommand for #name #ty_generics #where_clause {}
608         }
609     } else {
610         let empty_str = syn::LitStr::new("", Span::call_site());
611         let subcommand_name = type_attrs.name.as_ref().unwrap_or_else(|| {
612             errors.err(name, "`#[argh(name = \"...\")]` attribute is required for subcommands");
613             &empty_str
614         });
615         quote! {
616             #[automatically_derived]
617             impl #impl_generics argh::SubCommand for #name #ty_generics #where_clause {
618                 const COMMAND: &'static argh::CommandInfo = &argh::CommandInfo {
619                     name: #subcommand_name,
620                     description: #description,
621                 };
622             }
623         }
624     }
625 }
626 
627 /// Declare a local slots to store each field in during parsing.
628 ///
629 /// Most fields are stored in `Option<FieldType>` locals.
630 /// `argh(option)` fields are stored in a `ParseValueSlotTy` along with a
631 /// function that knows how to decode the appropriate value.
declare_local_storage_for_from_args_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a632 fn declare_local_storage_for_from_args_fields<'a>(
633     fields: &'a [StructField<'a>],
634 ) -> impl Iterator<Item = TokenStream> + 'a {
635     fields.iter().map(|field| {
636         let field_name = &field.field.ident;
637         let field_type = &field.ty_without_wrapper;
638 
639         // Wrap field types in `Option` if they aren't already `Option` or `Vec`-wrapped.
640         let field_slot_type = match field.optionality {
641             Optionality::Optional | Optionality::Repeating => (&field.field.ty).into_token_stream(),
642             Optionality::None | Optionality::Defaulted(_) => {
643                 quote! { std::option::Option<#field_type> }
644             }
645         };
646 
647         match field.kind {
648             FieldKind::Option | FieldKind::Positional => {
649                 let from_str_fn = match &field.attrs.from_str_fn {
650                     Some(from_str_fn) => from_str_fn.into_token_stream(),
651                     None => {
652                         quote! {
653                             <#field_type as argh::FromArgValue>::from_arg_value
654                         }
655                     }
656                 };
657 
658                 quote! {
659                     let mut #field_name: argh::ParseValueSlotTy<#field_slot_type, #field_type>
660                         = argh::ParseValueSlotTy {
661                             slot: std::default::Default::default(),
662                             parse_func: |_, value| { #from_str_fn(value) },
663                         };
664                 }
665             }
666             FieldKind::SubCommand => {
667                 quote! { let mut #field_name: #field_slot_type = None; }
668             }
669             FieldKind::Switch => {
670                 quote! { let mut #field_name: #field_slot_type = argh::Flag::default(); }
671             }
672         }
673     })
674 }
675 
676 /// Unwrap non-optional fields and take options out of their tuple slots.
unwrap_from_args_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a677 fn unwrap_from_args_fields<'a>(
678     fields: &'a [StructField<'a>],
679 ) -> impl Iterator<Item = TokenStream> + 'a {
680     fields.iter().map(|field| {
681         let field_name = field.name;
682         match field.kind {
683             FieldKind::Option | FieldKind::Positional => match &field.optionality {
684                 Optionality::None => quote! {
685                     #field_name: #field_name.slot.unwrap()
686                 },
687                 Optionality::Optional | Optionality::Repeating => {
688                     quote! { #field_name: #field_name.slot }
689                 }
690                 Optionality::Defaulted(tokens) => {
691                     quote! {
692                         #field_name: #field_name.slot.unwrap_or_else(|| #tokens)
693                     }
694                 }
695             },
696             FieldKind::Switch => field_name.into_token_stream(),
697             FieldKind::SubCommand => match field.optionality {
698                 Optionality::None => quote! { #field_name: #field_name.unwrap() },
699                 Optionality::Optional | Optionality::Repeating => field_name.into_token_stream(),
700                 Optionality::Defaulted(_) => unreachable!(),
701             },
702         }
703     })
704 }
705 
706 /// Declare a local slots to store each field in during parsing.
707 ///
708 /// Most fields are stored in `Option<FieldType>` locals.
709 /// `argh(option)` fields are stored in a `ParseValueSlotTy` along with a
710 /// function that knows how to decode the appropriate value.
declare_local_storage_for_redacted_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a711 fn declare_local_storage_for_redacted_fields<'a>(
712     fields: &'a [StructField<'a>],
713 ) -> impl Iterator<Item = TokenStream> + 'a {
714     fields.iter().map(|field| {
715         let field_name = &field.field.ident;
716 
717         match field.kind {
718             FieldKind::Switch => {
719                 quote! {
720                     let mut #field_name = argh::RedactFlag {
721                         slot: None,
722                     };
723                 }
724             }
725             FieldKind::Option => {
726                 let field_slot_type = match field.optionality {
727                     Optionality::Repeating => {
728                         quote! { std::vec::Vec<String> }
729                     }
730                     Optionality::None | Optionality::Optional | Optionality::Defaulted(_) => {
731                         quote! { std::option::Option<String> }
732                     }
733                 };
734 
735                 quote! {
736                     let mut #field_name: argh::ParseValueSlotTy::<#field_slot_type, String> =
737                         argh::ParseValueSlotTy {
738                         slot: std::default::Default::default(),
739                         parse_func: |arg, _| { Ok(arg.to_owned()) },
740                     };
741                 }
742             }
743             FieldKind::Positional => {
744                 let field_slot_type = match field.optionality {
745                     Optionality::Repeating => {
746                         quote! { std::vec::Vec<String> }
747                     }
748                     Optionality::None | Optionality::Optional | Optionality::Defaulted(_) => {
749                         quote! { std::option::Option<String> }
750                     }
751                 };
752 
753                 let arg_name = field.positional_arg_name();
754                 quote! {
755                     let mut #field_name: argh::ParseValueSlotTy::<#field_slot_type, String> =
756                         argh::ParseValueSlotTy {
757                         slot: std::default::Default::default(),
758                         parse_func: |_, _| { Ok(#arg_name.to_owned()) },
759                     };
760                 }
761             }
762             FieldKind::SubCommand => {
763                 quote! { let mut #field_name: std::option::Option<std::vec::Vec<String>> = None; }
764             }
765         }
766     })
767 }
768 
769 /// Unwrap non-optional fields and take options out of their tuple slots.
unwrap_redacted_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a770 fn unwrap_redacted_fields<'a>(
771     fields: &'a [StructField<'a>],
772 ) -> impl Iterator<Item = TokenStream> + 'a {
773     fields.iter().map(|field| {
774         let field_name = field.name;
775 
776         match field.kind {
777             FieldKind::Switch => {
778                 quote! {
779                     if let Some(__field_name) = #field_name.slot {
780                         __redacted.push(__field_name);
781                     }
782                 }
783             }
784             FieldKind::Option => match field.optionality {
785                 Optionality::Repeating => {
786                     quote! {
787                         __redacted.extend(#field_name.slot.into_iter());
788                     }
789                 }
790                 Optionality::None | Optionality::Optional | Optionality::Defaulted(_) => {
791                     quote! {
792                         if let Some(__field_name) = #field_name.slot {
793                             __redacted.push(__field_name);
794                         }
795                     }
796                 }
797             },
798             FieldKind::Positional => {
799                 quote! {
800                     __redacted.extend(#field_name.slot.into_iter());
801                 }
802             }
803             FieldKind::SubCommand => {
804                 quote! {
805                     if let Some(__subcommand_args) = #field_name {
806                         __redacted.extend(__subcommand_args.into_iter());
807                     }
808                 }
809             }
810         }
811     })
812 }
813 
814 /// Entries of tokens like `("--some-flag-key", 5)` that map from a flag key string
815 /// to an index in the output table.
flag_str_to_output_table_map_entries<'a>(fields: &'a [StructField<'a>]) -> Vec<TokenStream>816 fn flag_str_to_output_table_map_entries<'a>(fields: &'a [StructField<'a>]) -> Vec<TokenStream> {
817     let mut flag_str_to_output_table_map = vec![];
818     for (i, (field, long_name)) in fields
819         .iter()
820         .filter_map(|field| field.long_name.as_ref().map(|long_name| (field, long_name)))
821         .enumerate()
822     {
823         if let Some(short) = &field.attrs.short {
824             let short = format!("-{}", short.value());
825             flag_str_to_output_table_map.push(quote! { (#short, #i) });
826         }
827 
828         flag_str_to_output_table_map.push(quote! { (#long_name, #i) });
829     }
830     flag_str_to_output_table_map
831 }
832 
833 /// For each non-optional field, add an entry to the `argh::MissingRequirements`.
append_missing_requirements<'a>( mri: &syn::Ident, fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a834 fn append_missing_requirements<'a>(
835     // missing_requirements_ident
836     mri: &syn::Ident,
837     fields: &'a [StructField<'a>],
838 ) -> impl Iterator<Item = TokenStream> + 'a {
839     let mri = mri.clone();
840     fields.iter().filter(|f| f.optionality.is_required()).map(move |field| {
841         let field_name = field.name;
842         match field.kind {
843             FieldKind::Switch => unreachable!("switches are always optional"),
844             FieldKind::Positional => {
845                 let name = field.positional_arg_name();
846                 quote! {
847                     if #field_name.slot.is_none() {
848                         #mri.missing_positional_arg(#name)
849                     }
850                 }
851             }
852             FieldKind::Option => {
853                 let name = field.long_name.as_ref().expect("options always have a long name");
854                 quote! {
855                     if #field_name.slot.is_none() {
856                         #mri.missing_option(#name)
857                     }
858                 }
859             }
860             FieldKind::SubCommand => {
861                 let ty = field.ty_without_wrapper;
862                 quote! {
863                     if #field_name.is_none() {
864                         #mri.missing_subcommands(
865                             <#ty as argh::SubCommands>::COMMANDS
866                                 .iter()
867                                 .cloned()
868                                 .chain(
869                                     <#ty as argh::SubCommands>::dynamic_commands()
870                                         .iter()
871                                         .copied()
872                                 ),
873                         )
874                     }
875                 }
876             }
877         }
878     })
879 }
880 
881 /// Require that a type can be a `switch`.
882 /// Throws an error for all types except booleans and integers
ty_expect_switch(errors: &Errors, ty: &syn::Type) -> bool883 fn ty_expect_switch(errors: &Errors, ty: &syn::Type) -> bool {
884     fn ty_can_be_switch(ty: &syn::Type) -> bool {
885         if let syn::Type::Path(path) = ty {
886             if path.qself.is_some() {
887                 return false;
888             }
889             if path.path.segments.len() != 1 {
890                 return false;
891             }
892             let ident = &path.path.segments[0].ident;
893             // `Option<bool>` can be used as a `switch`.
894             if ident == "Option" {
895                 if let PathArguments::AngleBracketed(args) = &path.path.segments[0].arguments {
896                     if let GenericArgument::Type(Type::Path(p)) = &args.args[0] {
897                         if p.path.segments[0].ident == "bool" {
898                             return true;
899                         }
900                     }
901                 }
902             }
903             ["bool", "u8", "u16", "u32", "u64", "u128", "i8", "i16", "i32", "i64", "i128"]
904                 .iter()
905                 .any(|path| ident == path)
906         } else {
907             false
908         }
909     }
910 
911     let res = ty_can_be_switch(ty);
912     if !res {
913         errors.err(ty, "switches must be of type `bool`, `Option<bool>`, or integer type");
914     }
915     res
916 }
917 
918 /// Returns `Some(T)` if a type is `wrapper_name<T>` for any `wrapper_name` in `wrapper_names`.
ty_inner<'a>(wrapper_names: &[&str], ty: &'a syn::Type) -> Option<&'a syn::Type>919 fn ty_inner<'a>(wrapper_names: &[&str], ty: &'a syn::Type) -> Option<&'a syn::Type> {
920     if let syn::Type::Path(path) = ty {
921         if path.qself.is_some() {
922             return None;
923         }
924         // Since we only check the last path segment, it isn't necessarily the case that
925         // we're referring to `std::vec::Vec` or `std::option::Option`, but there isn't
926         // a fool proof way to check these since name resolution happens after macro expansion,
927         // so this is likely "good enough" (so long as people don't have their own types called
928         // `Option` or `Vec` that take one generic parameter they're looking to parse).
929         let last_segment = path.path.segments.last()?;
930         if !wrapper_names.iter().any(|name| last_segment.ident == *name) {
931             return None;
932         }
933         if let syn::PathArguments::AngleBracketed(gen_args) = &last_segment.arguments {
934             let generic_arg = gen_args.args.first()?;
935             if let syn::GenericArgument::Type(ty) = &generic_arg {
936                 return Some(ty);
937             }
938         }
939     }
940     None
941 }
942 
943 /// Implements `FromArgs` and `SubCommands` for a `#![derive(FromArgs)]` enum.
impl_from_args_enum( errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs, generic_args: &syn::Generics, de: &syn::DataEnum, ) -> TokenStream944 fn impl_from_args_enum(
945     errors: &Errors,
946     name: &syn::Ident,
947     type_attrs: &TypeAttrs,
948     generic_args: &syn::Generics,
949     de: &syn::DataEnum,
950 ) -> TokenStream {
951     parse_attrs::check_enum_type_attrs(errors, type_attrs, &de.enum_token.span);
952 
953     // An enum variant like `<name>(<ty>)`
954     struct SubCommandVariant<'a> {
955         name: &'a syn::Ident,
956         ty: &'a syn::Type,
957     }
958 
959     let mut dynamic_type_and_variant = None;
960 
961     let variants: Vec<SubCommandVariant<'_>> = de
962         .variants
963         .iter()
964         .filter_map(|variant| {
965             let name = &variant.ident;
966             let ty = enum_only_single_field_unnamed_variants(errors, &variant.fields)?;
967             if parse_attrs::VariantAttrs::parse(errors, variant).is_dynamic.is_some() {
968                 if dynamic_type_and_variant.is_some() {
969                     errors.err(variant, "Only one variant can have the `dynamic` attribute");
970                 }
971                 dynamic_type_and_variant = Some((ty, name));
972                 None
973             } else {
974                 Some(SubCommandVariant { name, ty })
975             }
976         })
977         .collect();
978 
979     let name_repeating = std::iter::repeat(name.clone());
980     let variant_ty = variants.iter().map(|x| x.ty).collect::<Vec<_>>();
981     let variant_names = variants.iter().map(|x| x.name).collect::<Vec<_>>();
982     let dynamic_from_args =
983         dynamic_type_and_variant.as_ref().map(|(dynamic_type, dynamic_variant)| {
984             quote! {
985                 if let Some(result) = <#dynamic_type as argh::DynamicSubCommand>::try_from_args(
986                     command_name, args) {
987                     return result.map(#name::#dynamic_variant);
988                 }
989             }
990         });
991     let dynamic_redact_arg_values = dynamic_type_and_variant.as_ref().map(|(dynamic_type, _)| {
992         quote! {
993             if let Some(result) = <#dynamic_type as argh::DynamicSubCommand>::try_redact_arg_values(
994                 command_name, args) {
995                 return result;
996             }
997         }
998     });
999     let dynamic_commands = dynamic_type_and_variant.as_ref().map(|(dynamic_type, _)| {
1000         quote! {
1001             fn dynamic_commands() -> &'static [&'static argh::CommandInfo] {
1002                 <#dynamic_type as argh::DynamicSubCommand>::commands()
1003             }
1004         }
1005     });
1006 
1007     let (impl_generics, ty_generics, where_clause) = generic_args.split_for_impl();
1008     quote! {
1009         impl #impl_generics argh::FromArgs for #name #ty_generics #where_clause {
1010             fn from_args(command_name: &[&str], args: &[&str])
1011                 -> std::result::Result<Self, argh::EarlyExit>
1012             {
1013                 let subcommand_name = if let Some(subcommand_name) = command_name.last() {
1014                     *subcommand_name
1015                 } else {
1016                     return Err(argh::EarlyExit::from("no subcommand name".to_owned()));
1017                 };
1018 
1019                 #(
1020                     if subcommand_name == <#variant_ty as argh::SubCommand>::COMMAND.name {
1021                         return Ok(#name_repeating::#variant_names(
1022                             <#variant_ty as argh::FromArgs>::from_args(command_name, args)?
1023                         ));
1024                     }
1025                 )*
1026 
1027                 #dynamic_from_args
1028 
1029                 Err(argh::EarlyExit::from("no subcommand matched".to_owned()))
1030             }
1031 
1032             fn redact_arg_values(command_name: &[&str], args: &[&str]) -> std::result::Result<Vec<String>, argh::EarlyExit> {
1033                 let subcommand_name = if let Some(subcommand_name) = command_name.last() {
1034                     *subcommand_name
1035                 } else {
1036                     return Err(argh::EarlyExit::from("no subcommand name".to_owned()));
1037                 };
1038 
1039                 #(
1040                     if subcommand_name == <#variant_ty as argh::SubCommand>::COMMAND.name {
1041                         return <#variant_ty as argh::FromArgs>::redact_arg_values(command_name, args);
1042                     }
1043                 )*
1044 
1045                 #dynamic_redact_arg_values
1046 
1047                 Err(argh::EarlyExit::from("no subcommand matched".to_owned()))
1048             }
1049         }
1050 
1051         impl #impl_generics argh::SubCommands for #name #ty_generics #where_clause {
1052             const COMMANDS: &'static [&'static argh::CommandInfo] = &[#(
1053                 <#variant_ty as argh::SubCommand>::COMMAND,
1054             )*];
1055 
1056             #dynamic_commands
1057         }
1058     }
1059 }
1060 
1061 /// Returns `Some(Bar)` if the field is a single-field unnamed variant like `Foo(Bar)`.
1062 /// Otherwise, generates an error.
enum_only_single_field_unnamed_variants<'a>( errors: &Errors, variant_fields: &'a syn::Fields, ) -> Option<&'a syn::Type>1063 fn enum_only_single_field_unnamed_variants<'a>(
1064     errors: &Errors,
1065     variant_fields: &'a syn::Fields,
1066 ) -> Option<&'a syn::Type> {
1067     macro_rules! with_enum_suggestion {
1068         ($help_text:literal) => {
1069             concat!(
1070                 $help_text,
1071                 "\nInstead, use a variant with a single unnamed field for each subcommand:\n",
1072                 "    enum MyCommandEnum {\n",
1073                 "        SubCommandOne(SubCommandOne),\n",
1074                 "        SubCommandTwo(SubCommandTwo),\n",
1075                 "    }",
1076             )
1077         };
1078     }
1079 
1080     match variant_fields {
1081         syn::Fields::Named(fields) => {
1082             errors.err(
1083                 fields,
1084                 with_enum_suggestion!(
1085                     "`#![derive(FromArgs)]` `enum`s do not support variants with named fields."
1086                 ),
1087             );
1088             None
1089         }
1090         syn::Fields::Unit => {
1091             errors.err(
1092                 variant_fields,
1093                 with_enum_suggestion!(
1094                     "`#![derive(FromArgs)]` does not support `enum`s with no variants."
1095                 ),
1096             );
1097             None
1098         }
1099         syn::Fields::Unnamed(fields) => {
1100             if fields.unnamed.len() != 1 {
1101                 errors.err(
1102                     fields,
1103                     with_enum_suggestion!(
1104                         "`#![derive(FromArgs)]` `enum` variants must only contain one field."
1105                     ),
1106                 );
1107                 None
1108             } else {
1109                 // `unwrap` is okay because of the length check above.
1110                 let first_field = fields.unnamed.first().unwrap();
1111                 Some(&first_field.ty)
1112             }
1113         }
1114     }
1115 }
1116