1 #![recursion_limit = "256"]
2 // Copyright (c) 2020 Google LLC All rights reserved.
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5
6 /// Implementation of the `FromArgs` and `argh(...)` derive attributes.
7 ///
8 /// For more thorough documentation, see the `argh` crate itself.
9 extern crate proc_macro;
10
11 use {
12 crate::{
13 errors::Errors,
14 parse_attrs::{check_long_name, FieldAttrs, FieldKind, TypeAttrs},
15 },
16 proc_macro2::{Span, TokenStream},
17 quote::{quote, quote_spanned, ToTokens},
18 std::{collections::HashMap, str::FromStr},
19 syn::{spanned::Spanned, GenericArgument, LitStr, PathArguments, Type},
20 };
21
22 mod args_info;
23 mod errors;
24 mod help;
25 mod parse_attrs;
26
27 /// Entrypoint for `#[derive(FromArgs)]`.
28 #[proc_macro_derive(FromArgs, attributes(argh))]
argh_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream29 pub fn argh_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
30 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
31 let gen = impl_from_args(&ast);
32 gen.into()
33 }
34
35 /// Entrypoint for `#[derive(ArgsInfo)]`.
36 #[proc_macro_derive(ArgsInfo, attributes(argh))]
args_info_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream37 pub fn args_info_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
38 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
39 let gen = args_info::impl_args_info(&ast);
40 gen.into()
41 }
42
43 /// Transform the input into a token stream containing any generated implementations,
44 /// as well as all errors that occurred.
impl_from_args(input: &syn::DeriveInput) -> TokenStream45 fn impl_from_args(input: &syn::DeriveInput) -> TokenStream {
46 let errors = &Errors::default();
47 let type_attrs = &TypeAttrs::parse(errors, input);
48 let mut output_tokens = match &input.data {
49 syn::Data::Struct(ds) => {
50 impl_from_args_struct(errors, &input.ident, type_attrs, &input.generics, ds)
51 }
52 syn::Data::Enum(de) => {
53 impl_from_args_enum(errors, &input.ident, type_attrs, &input.generics, de)
54 }
55 syn::Data::Union(_) => {
56 errors.err(input, "`#[derive(FromArgs)]` cannot be applied to unions");
57 TokenStream::new()
58 }
59 };
60 errors.to_tokens(&mut output_tokens);
61 output_tokens
62 }
63
64 /// The kind of optionality a parameter has.
65 enum Optionality {
66 None,
67 Defaulted(TokenStream),
68 Optional,
69 Repeating,
70 }
71
72 impl PartialEq<Optionality> for Optionality {
eq(&self, other: &Optionality) -> bool73 fn eq(&self, other: &Optionality) -> bool {
74 use Optionality::*;
75 // NB: (Defaulted, Defaulted) can't contain the same token streams
76 matches!((self, other), (Optional, Optional) | (Repeating, Repeating))
77 }
78 }
79
80 impl Optionality {
81 /// Whether or not this is `Optionality::None`
is_required(&self) -> bool82 fn is_required(&self) -> bool {
83 matches!(self, Optionality::None)
84 }
85 }
86
87 /// A field of a `#![derive(FromArgs)]` struct with attributes and some other
88 /// notable metadata appended.
89 struct StructField<'a> {
90 /// The original parsed field
91 field: &'a syn::Field,
92 /// The parsed attributes of the field
93 attrs: FieldAttrs,
94 /// The field name. This is contained optionally inside `field`,
95 /// but is duplicated non-optionally here to indicate that all field that
96 /// have reached this point must have a field name, and it no longer
97 /// needs to be unwrapped.
98 name: &'a syn::Ident,
99 /// Similar to `name` above, this is contained optionally inside `FieldAttrs`,
100 /// but here is fully present to indicate that we only have to consider fields
101 /// with a valid `kind` at this point.
102 kind: FieldKind,
103 // If `field.ty` is `Vec<T>` or `Option<T>`, this is `T`, otherwise it's `&field.ty`.
104 // This is used to enable consistent parsing code between optional and non-optional
105 // keyed and subcommand fields.
106 ty_without_wrapper: &'a syn::Type,
107 // Whether the field represents an optional value, such as an `Option` subcommand field
108 // or an `Option` or `Vec` keyed argument, or if it has a `default`.
109 optionality: Optionality,
110 // The `--`-prefixed name of the option, if one exists.
111 long_name: Option<String>,
112 }
113
114 impl<'a> StructField<'a> {
115 /// Attempts to parse a field of a `#[derive(FromArgs)]` struct, pulling out the
116 /// fields required for code generation.
new(errors: &Errors, field: &'a syn::Field, attrs: FieldAttrs) -> Option<Self>117 fn new(errors: &Errors, field: &'a syn::Field, attrs: FieldAttrs) -> Option<Self> {
118 let name = field.ident.as_ref().expect("missing ident for named field");
119
120 // Ensure that one "kind" is present (switch, option, subcommand, positional)
121 let kind = if let Some(field_type) = &attrs.field_type {
122 field_type.kind
123 } else {
124 errors.err(
125 field,
126 concat!(
127 "Missing `argh` field kind attribute.\n",
128 "Expected one of: `switch`, `option`, `remaining`, `subcommand`, `positional`",
129 ),
130 );
131 return None;
132 };
133
134 // Parse out whether a field is optional (`Option` or `Vec`).
135 let optionality;
136 let ty_without_wrapper;
137 match kind {
138 FieldKind::Switch => {
139 if !ty_expect_switch(errors, &field.ty) {
140 return None;
141 }
142 optionality = Optionality::Optional;
143 ty_without_wrapper = &field.ty;
144 }
145 FieldKind::Option | FieldKind::Positional => {
146 if let Some(default) = &attrs.default {
147 let tokens = match TokenStream::from_str(&default.value()) {
148 Ok(tokens) => tokens,
149 Err(_) => {
150 errors.err(&default, "Invalid tokens: unable to lex `default` value");
151 return None;
152 }
153 };
154 // Set the span of the generated tokens to the string literal
155 let tokens: TokenStream = tokens
156 .into_iter()
157 .map(|mut tree| {
158 tree.set_span(default.span());
159 tree
160 })
161 .collect();
162 optionality = Optionality::Defaulted(tokens);
163 ty_without_wrapper = &field.ty;
164 } else {
165 let mut inner = None;
166 optionality = if let Some(x) = ty_inner(&["Option"], &field.ty) {
167 inner = Some(x);
168 Optionality::Optional
169 } else if let Some(x) = ty_inner(&["Vec"], &field.ty) {
170 inner = Some(x);
171 Optionality::Repeating
172 } else {
173 Optionality::None
174 };
175 ty_without_wrapper = inner.unwrap_or(&field.ty);
176 }
177 }
178 FieldKind::SubCommand => {
179 let inner = ty_inner(&["Option"], &field.ty);
180 optionality =
181 if inner.is_some() { Optionality::Optional } else { Optionality::None };
182 ty_without_wrapper = inner.unwrap_or(&field.ty);
183 }
184 }
185
186 // Determine the "long" name of options and switches.
187 // Defaults to the kebab-case'd field name if `#[argh(long = "...")]` is omitted.
188 let long_name = match kind {
189 FieldKind::Switch | FieldKind::Option => {
190 let long_name = attrs.long.as_ref().map(syn::LitStr::value).unwrap_or_else(|| {
191 let kebab_name = to_kebab_case(&name.to_string());
192 check_long_name(errors, name, &kebab_name);
193 kebab_name
194 });
195 if long_name == "help" {
196 errors.err(field, "Custom `--help` flags are not supported.");
197 }
198 let long_name = format!("--{}", long_name);
199 Some(long_name)
200 }
201 FieldKind::SubCommand | FieldKind::Positional => None,
202 };
203
204 Some(StructField { field, attrs, kind, optionality, ty_without_wrapper, name, long_name })
205 }
206
positional_arg_name(&self) -> String207 pub(crate) fn positional_arg_name(&self) -> String {
208 self.attrs
209 .arg_name
210 .as_ref()
211 .map(LitStr::value)
212 .unwrap_or_else(|| self.name.to_string().trim_matches('_').to_owned())
213 }
214 }
215
/// Converts a `snake_case` identifier into `kebab-case`.
///
/// Empty segments (from leading, trailing, or repeated underscores) are dropped
/// and letter case is left untouched, so `"__foo__Bar_"` becomes `"foo-Bar"`.
fn to_kebab_case(s: &str) -> String {
    let words: Vec<&str> = s.split('_').filter(|word| !word.is_empty()).collect();
    words.join("-")
}
227
#[test]
fn test_kebabs() {
    // (input, expected kebab-case output)
    let cases: &[(&str, &str)] = &[
        ("", ""),
        ("_", ""),
        ("foo", "foo"),
        ("__foo_", "foo"),
        ("foo_bar", "foo-bar"),
        ("foo__Bar", "foo-Bar"),
        ("foo_bar__baz_", "foo-bar-baz"),
    ];
    for &(input, expected) in cases {
        assert_eq!(to_kebab_case(input), expected, "to_kebab_case({:?})", input);
    }
}
243
244 /// Implements `FromArgs` and `TopLevelCommand` or `SubCommand` for a `#[derive(FromArgs)]` struct.
impl_from_args_struct( errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs, generic_args: &syn::Generics, ds: &syn::DataStruct, ) -> TokenStream245 fn impl_from_args_struct(
246 errors: &Errors,
247 name: &syn::Ident,
248 type_attrs: &TypeAttrs,
249 generic_args: &syn::Generics,
250 ds: &syn::DataStruct,
251 ) -> TokenStream {
252 let fields = match &ds.fields {
253 syn::Fields::Named(fields) => fields,
254 syn::Fields::Unnamed(_) => {
255 errors.err(
256 &ds.struct_token,
257 "`#![derive(FromArgs)]` is not currently supported on tuple structs",
258 );
259 return TokenStream::new();
260 }
261 syn::Fields::Unit => {
262 errors.err(&ds.struct_token, "#![derive(FromArgs)]` cannot be applied to unit structs");
263 return TokenStream::new();
264 }
265 };
266
267 let fields: Vec<_> = fields
268 .named
269 .iter()
270 .filter_map(|field| {
271 let attrs = FieldAttrs::parse(errors, field);
272 StructField::new(errors, field, attrs)
273 })
274 .collect();
275
276 ensure_unique_names(errors, &fields);
277 ensure_only_last_positional_is_optional(errors, &fields);
278
279 let impl_span = Span::call_site();
280
281 let from_args_method = impl_from_args_struct_from_args(errors, type_attrs, &fields);
282
283 let redact_arg_values_method =
284 impl_from_args_struct_redact_arg_values(errors, type_attrs, &fields);
285
286 let top_or_sub_cmd_impl = top_or_sub_cmd_impl(errors, name, type_attrs, generic_args);
287
288 let (impl_generics, ty_generics, where_clause) = generic_args.split_for_impl();
289 let trait_impl = quote_spanned! { impl_span =>
290 #[automatically_derived]
291 impl #impl_generics argh::FromArgs for #name #ty_generics #where_clause {
292 #from_args_method
293
294 #redact_arg_values_method
295 }
296
297 #top_or_sub_cmd_impl
298 };
299
300 trait_impl
301 }
302
impl_from_args_struct_from_args<'a>( errors: &Errors, type_attrs: &TypeAttrs, fields: &'a [StructField<'a>], ) -> TokenStream303 fn impl_from_args_struct_from_args<'a>(
304 errors: &Errors,
305 type_attrs: &TypeAttrs,
306 fields: &'a [StructField<'a>],
307 ) -> TokenStream {
308 let init_fields = declare_local_storage_for_from_args_fields(fields);
309 let unwrap_fields = unwrap_from_args_fields(fields);
310 let positional_fields: Vec<&StructField<'_>> =
311 fields.iter().filter(|field| field.kind == FieldKind::Positional).collect();
312 let positional_field_idents = positional_fields.iter().map(|field| &field.field.ident);
313 let positional_field_names = positional_fields.iter().map(|field| field.name.to_string());
314 let last_positional_is_repeating = positional_fields
315 .last()
316 .map(|field| field.optionality == Optionality::Repeating)
317 .unwrap_or(false);
318 let last_positional_is_greedy = positional_fields
319 .last()
320 .map(|field| field.kind == FieldKind::Positional && field.attrs.greedy.is_some())
321 .unwrap_or(false);
322
323 let flag_output_table = fields.iter().filter_map(|field| {
324 let field_name = &field.field.ident;
325 match field.kind {
326 FieldKind::Option => Some(quote! { argh::ParseStructOption::Value(&mut #field_name) }),
327 FieldKind::Switch => Some(quote! { argh::ParseStructOption::Flag(&mut #field_name) }),
328 FieldKind::SubCommand | FieldKind::Positional => None,
329 }
330 });
331
332 let flag_str_to_output_table_map = flag_str_to_output_table_map_entries(fields);
333
334 let mut subcommands_iter =
335 fields.iter().filter(|field| field.kind == FieldKind::SubCommand).fuse();
336
337 let subcommand: Option<&StructField<'_>> = subcommands_iter.next();
338 for dup_subcommand in subcommands_iter {
339 errors.duplicate_attrs("subcommand", subcommand.unwrap().field, dup_subcommand.field);
340 }
341
342 let impl_span = Span::call_site();
343
344 let missing_requirements_ident = syn::Ident::new("__missing_requirements", impl_span);
345
346 let append_missing_requirements =
347 append_missing_requirements(&missing_requirements_ident, fields);
348
349 let parse_subcommands = if let Some(subcommand) = subcommand {
350 let name = subcommand.name;
351 let ty = subcommand.ty_without_wrapper;
352 quote_spanned! { impl_span =>
353 Some(argh::ParseStructSubCommand {
354 subcommands: <#ty as argh::SubCommands>::COMMANDS,
355 dynamic_subcommands: &<#ty as argh::SubCommands>::dynamic_commands(),
356 parse_func: &mut |__command, __remaining_args| {
357 #name = Some(<#ty as argh::FromArgs>::from_args(__command, __remaining_args)?);
358 Ok(())
359 },
360 })
361 }
362 } else {
363 quote_spanned! { impl_span => None }
364 };
365
366 // Identifier referring to a value containing the name of the current command as an `&[&str]`.
367 let cmd_name_str_array_ident = syn::Ident::new("__cmd_name", impl_span);
368 let help = help::help(errors, cmd_name_str_array_ident, type_attrs, fields, subcommand);
369
370 let method_impl = quote_spanned! { impl_span =>
371 fn from_args(__cmd_name: &[&str], __args: &[&str])
372 -> std::result::Result<Self, argh::EarlyExit>
373 {
374 #![allow(clippy::unwrap_in_result)]
375
376 #( #init_fields )*
377
378 argh::parse_struct_args(
379 __cmd_name,
380 __args,
381 argh::ParseStructOptions {
382 arg_to_slot: &[ #( #flag_str_to_output_table_map ,)* ],
383 slots: &mut [ #( #flag_output_table, )* ],
384 },
385 argh::ParseStructPositionals {
386 positionals: &mut [
387 #(
388 argh::ParseStructPositional {
389 name: #positional_field_names,
390 slot: &mut #positional_field_idents as &mut argh::ParseValueSlot,
391 },
392 )*
393 ],
394 last_is_repeating: #last_positional_is_repeating,
395 last_is_greedy: #last_positional_is_greedy,
396 },
397 #parse_subcommands,
398 &|| #help,
399 )?;
400
401 let mut #missing_requirements_ident = argh::MissingRequirements::default();
402 #(
403 #append_missing_requirements
404 )*
405 #missing_requirements_ident.err_on_any()?;
406
407 Ok(Self {
408 #( #unwrap_fields, )*
409 })
410 }
411 };
412
413 method_impl
414 }
415
impl_from_args_struct_redact_arg_values<'a>( errors: &Errors, type_attrs: &TypeAttrs, fields: &'a [StructField<'a>], ) -> TokenStream416 fn impl_from_args_struct_redact_arg_values<'a>(
417 errors: &Errors,
418 type_attrs: &TypeAttrs,
419 fields: &'a [StructField<'a>],
420 ) -> TokenStream {
421 let init_fields = declare_local_storage_for_redacted_fields(fields);
422 let unwrap_fields = unwrap_redacted_fields(fields);
423
424 let positional_fields: Vec<&StructField<'_>> =
425 fields.iter().filter(|field| field.kind == FieldKind::Positional).collect();
426 let positional_field_idents = positional_fields.iter().map(|field| &field.field.ident);
427 let positional_field_names = positional_fields.iter().map(|field| field.name.to_string());
428 let last_positional_is_repeating = positional_fields
429 .last()
430 .map(|field| field.optionality == Optionality::Repeating)
431 .unwrap_or(false);
432 let last_positional_is_greedy = positional_fields
433 .last()
434 .map(|field| field.kind == FieldKind::Positional && field.attrs.greedy.is_some())
435 .unwrap_or(false);
436
437 let flag_output_table = fields.iter().filter_map(|field| {
438 let field_name = &field.field.ident;
439 match field.kind {
440 FieldKind::Option => Some(quote! { argh::ParseStructOption::Value(&mut #field_name) }),
441 FieldKind::Switch => Some(quote! { argh::ParseStructOption::Flag(&mut #field_name) }),
442 FieldKind::SubCommand | FieldKind::Positional => None,
443 }
444 });
445
446 let flag_str_to_output_table_map = flag_str_to_output_table_map_entries(fields);
447
448 let mut subcommands_iter =
449 fields.iter().filter(|field| field.kind == FieldKind::SubCommand).fuse();
450
451 let subcommand: Option<&StructField<'_>> = subcommands_iter.next();
452 for dup_subcommand in subcommands_iter {
453 errors.duplicate_attrs("subcommand", subcommand.unwrap().field, dup_subcommand.field);
454 }
455
456 let impl_span = Span::call_site();
457
458 let missing_requirements_ident = syn::Ident::new("__missing_requirements", impl_span);
459
460 let append_missing_requirements =
461 append_missing_requirements(&missing_requirements_ident, fields);
462
463 let redact_subcommands = if let Some(subcommand) = subcommand {
464 let name = subcommand.name;
465 let ty = subcommand.ty_without_wrapper;
466 quote_spanned! { impl_span =>
467 Some(argh::ParseStructSubCommand {
468 subcommands: <#ty as argh::SubCommands>::COMMANDS,
469 dynamic_subcommands: &<#ty as argh::SubCommands>::dynamic_commands(),
470 parse_func: &mut |__command, __remaining_args| {
471 #name = Some(<#ty as argh::FromArgs>::redact_arg_values(__command, __remaining_args)?);
472 Ok(())
473 },
474 })
475 }
476 } else {
477 quote_spanned! { impl_span => None }
478 };
479
480 let unwrap_cmd_name_err_string = if type_attrs.is_subcommand.is_none() {
481 quote! { "no command name" }
482 } else {
483 quote! { "no subcommand name" }
484 };
485
486 // Identifier referring to a value containing the name of the current command as an `&[&str]`.
487 let cmd_name_str_array_ident = syn::Ident::new("__cmd_name", impl_span);
488 let help = help::help(errors, cmd_name_str_array_ident, type_attrs, fields, subcommand);
489
490 let method_impl = quote_spanned! { impl_span =>
491 fn redact_arg_values(__cmd_name: &[&str], __args: &[&str]) -> std::result::Result<Vec<String>, argh::EarlyExit> {
492 #( #init_fields )*
493
494 argh::parse_struct_args(
495 __cmd_name,
496 __args,
497 argh::ParseStructOptions {
498 arg_to_slot: &[ #( #flag_str_to_output_table_map ,)* ],
499 slots: &mut [ #( #flag_output_table, )* ],
500 },
501 argh::ParseStructPositionals {
502 positionals: &mut [
503 #(
504 argh::ParseStructPositional {
505 name: #positional_field_names,
506 slot: &mut #positional_field_idents as &mut argh::ParseValueSlot,
507 },
508 )*
509 ],
510 last_is_repeating: #last_positional_is_repeating,
511 last_is_greedy: #last_positional_is_greedy,
512 },
513 #redact_subcommands,
514 &|| #help,
515 )?;
516
517 let mut #missing_requirements_ident = argh::MissingRequirements::default();
518 #(
519 #append_missing_requirements
520 )*
521 #missing_requirements_ident.err_on_any()?;
522
523 let mut __redacted = vec![
524 if let Some(cmd_name) = __cmd_name.last() {
525 (*cmd_name).to_owned()
526 } else {
527 return Err(argh::EarlyExit::from(#unwrap_cmd_name_err_string.to_owned()));
528 }
529 ];
530
531 #( #unwrap_fields )*
532
533 Ok(__redacted)
534 }
535 };
536
537 method_impl
538 }
539
540 /// Ensures that only the last positional arg is non-required.
ensure_only_last_positional_is_optional(errors: &Errors, fields: &[StructField<'_>])541 fn ensure_only_last_positional_is_optional(errors: &Errors, fields: &[StructField<'_>]) {
542 let mut first_non_required_span = None;
543 for field in fields {
544 if field.kind == FieldKind::Positional {
545 if let Some(first) = first_non_required_span {
546 errors.err_span(
547 first,
548 "Only the last positional argument may be `Option`, `Vec`, or defaulted.",
549 );
550 errors.err(&field.field, "Later positional argument declared here.");
551 return;
552 }
553 if !field.optionality.is_required() {
554 first_non_required_span = Some(field.field.span());
555 }
556 }
557 }
558 }
559
560 /// Ensures that only one short or long name is used.
ensure_unique_names(errors: &Errors, fields: &[StructField<'_>])561 fn ensure_unique_names(errors: &Errors, fields: &[StructField<'_>]) {
562 let mut seen_short_names = HashMap::new();
563 let mut seen_long_names = HashMap::new();
564
565 for field in fields {
566 if let Some(short_name) = &field.attrs.short {
567 let short_name = short_name.value();
568 if let Some(first_use_field) = seen_short_names.get(&short_name) {
569 errors.err_span_tokens(
570 first_use_field,
571 &format!("The short name of \"-{}\" was already used here.", short_name),
572 );
573 errors.err_span_tokens(field.field, "Later usage here.");
574 }
575
576 seen_short_names.insert(short_name, &field.field);
577 }
578
579 if let Some(long_name) = &field.long_name {
580 if let Some(first_use_field) = seen_long_names.get(&long_name) {
581 errors.err_span_tokens(
582 *first_use_field,
583 &format!("The long name of \"{}\" was already used here.", long_name),
584 );
585 errors.err_span_tokens(field.field, "Later usage here.");
586 }
587
588 seen_long_names.insert(long_name, field.field);
589 }
590 }
591 }
592
593 /// Implement `argh::TopLevelCommand` or `argh::SubCommand` as appropriate.
top_or_sub_cmd_impl( errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs, generic_args: &syn::Generics, ) -> TokenStream594 fn top_or_sub_cmd_impl(
595 errors: &Errors,
596 name: &syn::Ident,
597 type_attrs: &TypeAttrs,
598 generic_args: &syn::Generics,
599 ) -> TokenStream {
600 let description =
601 help::require_description(errors, name.span(), &type_attrs.description, "type");
602 let (impl_generics, ty_generics, where_clause) = generic_args.split_for_impl();
603 if type_attrs.is_subcommand.is_none() {
604 // Not a subcommand
605 quote! {
606 #[automatically_derived]
607 impl #impl_generics argh::TopLevelCommand for #name #ty_generics #where_clause {}
608 }
609 } else {
610 let empty_str = syn::LitStr::new("", Span::call_site());
611 let subcommand_name = type_attrs.name.as_ref().unwrap_or_else(|| {
612 errors.err(name, "`#[argh(name = \"...\")]` attribute is required for subcommands");
613 &empty_str
614 });
615 quote! {
616 #[automatically_derived]
617 impl #impl_generics argh::SubCommand for #name #ty_generics #where_clause {
618 const COMMAND: &'static argh::CommandInfo = &argh::CommandInfo {
619 name: #subcommand_name,
620 description: #description,
621 };
622 }
623 }
624 }
625 }
626
627 /// Declare a local slots to store each field in during parsing.
628 ///
629 /// Most fields are stored in `Option<FieldType>` locals.
630 /// `argh(option)` fields are stored in a `ParseValueSlotTy` along with a
631 /// function that knows how to decode the appropriate value.
declare_local_storage_for_from_args_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a632 fn declare_local_storage_for_from_args_fields<'a>(
633 fields: &'a [StructField<'a>],
634 ) -> impl Iterator<Item = TokenStream> + 'a {
635 fields.iter().map(|field| {
636 let field_name = &field.field.ident;
637 let field_type = &field.ty_without_wrapper;
638
639 // Wrap field types in `Option` if they aren't already `Option` or `Vec`-wrapped.
640 let field_slot_type = match field.optionality {
641 Optionality::Optional | Optionality::Repeating => (&field.field.ty).into_token_stream(),
642 Optionality::None | Optionality::Defaulted(_) => {
643 quote! { std::option::Option<#field_type> }
644 }
645 };
646
647 match field.kind {
648 FieldKind::Option | FieldKind::Positional => {
649 let from_str_fn = match &field.attrs.from_str_fn {
650 Some(from_str_fn) => from_str_fn.into_token_stream(),
651 None => {
652 quote! {
653 <#field_type as argh::FromArgValue>::from_arg_value
654 }
655 }
656 };
657
658 quote! {
659 let mut #field_name: argh::ParseValueSlotTy<#field_slot_type, #field_type>
660 = argh::ParseValueSlotTy {
661 slot: std::default::Default::default(),
662 parse_func: |_, value| { #from_str_fn(value) },
663 };
664 }
665 }
666 FieldKind::SubCommand => {
667 quote! { let mut #field_name: #field_slot_type = None; }
668 }
669 FieldKind::Switch => {
670 quote! { let mut #field_name: #field_slot_type = argh::Flag::default(); }
671 }
672 }
673 })
674 }
675
676 /// Unwrap non-optional fields and take options out of their tuple slots.
unwrap_from_args_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a677 fn unwrap_from_args_fields<'a>(
678 fields: &'a [StructField<'a>],
679 ) -> impl Iterator<Item = TokenStream> + 'a {
680 fields.iter().map(|field| {
681 let field_name = field.name;
682 match field.kind {
683 FieldKind::Option | FieldKind::Positional => match &field.optionality {
684 Optionality::None => quote! {
685 #field_name: #field_name.slot.unwrap()
686 },
687 Optionality::Optional | Optionality::Repeating => {
688 quote! { #field_name: #field_name.slot }
689 }
690 Optionality::Defaulted(tokens) => {
691 quote! {
692 #field_name: #field_name.slot.unwrap_or_else(|| #tokens)
693 }
694 }
695 },
696 FieldKind::Switch => field_name.into_token_stream(),
697 FieldKind::SubCommand => match field.optionality {
698 Optionality::None => quote! { #field_name: #field_name.unwrap() },
699 Optionality::Optional | Optionality::Repeating => field_name.into_token_stream(),
700 Optionality::Defaulted(_) => unreachable!(),
701 },
702 }
703 })
704 }
705
706 /// Declare a local slots to store each field in during parsing.
707 ///
708 /// Most fields are stored in `Option<FieldType>` locals.
709 /// `argh(option)` fields are stored in a `ParseValueSlotTy` along with a
710 /// function that knows how to decode the appropriate value.
declare_local_storage_for_redacted_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a711 fn declare_local_storage_for_redacted_fields<'a>(
712 fields: &'a [StructField<'a>],
713 ) -> impl Iterator<Item = TokenStream> + 'a {
714 fields.iter().map(|field| {
715 let field_name = &field.field.ident;
716
717 match field.kind {
718 FieldKind::Switch => {
719 quote! {
720 let mut #field_name = argh::RedactFlag {
721 slot: None,
722 };
723 }
724 }
725 FieldKind::Option => {
726 let field_slot_type = match field.optionality {
727 Optionality::Repeating => {
728 quote! { std::vec::Vec<String> }
729 }
730 Optionality::None | Optionality::Optional | Optionality::Defaulted(_) => {
731 quote! { std::option::Option<String> }
732 }
733 };
734
735 quote! {
736 let mut #field_name: argh::ParseValueSlotTy::<#field_slot_type, String> =
737 argh::ParseValueSlotTy {
738 slot: std::default::Default::default(),
739 parse_func: |arg, _| { Ok(arg.to_owned()) },
740 };
741 }
742 }
743 FieldKind::Positional => {
744 let field_slot_type = match field.optionality {
745 Optionality::Repeating => {
746 quote! { std::vec::Vec<String> }
747 }
748 Optionality::None | Optionality::Optional | Optionality::Defaulted(_) => {
749 quote! { std::option::Option<String> }
750 }
751 };
752
753 let arg_name = field.positional_arg_name();
754 quote! {
755 let mut #field_name: argh::ParseValueSlotTy::<#field_slot_type, String> =
756 argh::ParseValueSlotTy {
757 slot: std::default::Default::default(),
758 parse_func: |_, _| { Ok(#arg_name.to_owned()) },
759 };
760 }
761 }
762 FieldKind::SubCommand => {
763 quote! { let mut #field_name: std::option::Option<std::vec::Vec<String>> = None; }
764 }
765 }
766 })
767 }
768
769 /// Unwrap non-optional fields and take options out of their tuple slots.
unwrap_redacted_fields<'a>( fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a770 fn unwrap_redacted_fields<'a>(
771 fields: &'a [StructField<'a>],
772 ) -> impl Iterator<Item = TokenStream> + 'a {
773 fields.iter().map(|field| {
774 let field_name = field.name;
775
776 match field.kind {
777 FieldKind::Switch => {
778 quote! {
779 if let Some(__field_name) = #field_name.slot {
780 __redacted.push(__field_name);
781 }
782 }
783 }
784 FieldKind::Option => match field.optionality {
785 Optionality::Repeating => {
786 quote! {
787 __redacted.extend(#field_name.slot.into_iter());
788 }
789 }
790 Optionality::None | Optionality::Optional | Optionality::Defaulted(_) => {
791 quote! {
792 if let Some(__field_name) = #field_name.slot {
793 __redacted.push(__field_name);
794 }
795 }
796 }
797 },
798 FieldKind::Positional => {
799 quote! {
800 __redacted.extend(#field_name.slot.into_iter());
801 }
802 }
803 FieldKind::SubCommand => {
804 quote! {
805 if let Some(__subcommand_args) = #field_name {
806 __redacted.extend(__subcommand_args.into_iter());
807 }
808 }
809 }
810 }
811 })
812 }
813
814 /// Entries of tokens like `("--some-flag-key", 5)` that map from a flag key string
815 /// to an index in the output table.
flag_str_to_output_table_map_entries<'a>(fields: &'a [StructField<'a>]) -> Vec<TokenStream>816 fn flag_str_to_output_table_map_entries<'a>(fields: &'a [StructField<'a>]) -> Vec<TokenStream> {
817 let mut flag_str_to_output_table_map = vec![];
818 for (i, (field, long_name)) in fields
819 .iter()
820 .filter_map(|field| field.long_name.as_ref().map(|long_name| (field, long_name)))
821 .enumerate()
822 {
823 if let Some(short) = &field.attrs.short {
824 let short = format!("-{}", short.value());
825 flag_str_to_output_table_map.push(quote! { (#short, #i) });
826 }
827
828 flag_str_to_output_table_map.push(quote! { (#long_name, #i) });
829 }
830 flag_str_to_output_table_map
831 }
832
833 /// For each non-optional field, add an entry to the `argh::MissingRequirements`.
append_missing_requirements<'a>( mri: &syn::Ident, fields: &'a [StructField<'a>], ) -> impl Iterator<Item = TokenStream> + 'a834 fn append_missing_requirements<'a>(
835 // missing_requirements_ident
836 mri: &syn::Ident,
837 fields: &'a [StructField<'a>],
838 ) -> impl Iterator<Item = TokenStream> + 'a {
839 let mri = mri.clone();
840 fields.iter().filter(|f| f.optionality.is_required()).map(move |field| {
841 let field_name = field.name;
842 match field.kind {
843 FieldKind::Switch => unreachable!("switches are always optional"),
844 FieldKind::Positional => {
845 let name = field.positional_arg_name();
846 quote! {
847 if #field_name.slot.is_none() {
848 #mri.missing_positional_arg(#name)
849 }
850 }
851 }
852 FieldKind::Option => {
853 let name = field.long_name.as_ref().expect("options always have a long name");
854 quote! {
855 if #field_name.slot.is_none() {
856 #mri.missing_option(#name)
857 }
858 }
859 }
860 FieldKind::SubCommand => {
861 let ty = field.ty_without_wrapper;
862 quote! {
863 if #field_name.is_none() {
864 #mri.missing_subcommands(
865 <#ty as argh::SubCommands>::COMMANDS
866 .iter()
867 .cloned()
868 .chain(
869 <#ty as argh::SubCommands>::dynamic_commands()
870 .iter()
871 .copied()
872 ),
873 )
874 }
875 }
876 }
877 }
878 })
879 }
880
881 /// Require that a type can be a `switch`.
882 /// Throws an error for all types except booleans and integers
ty_expect_switch(errors: &Errors, ty: &syn::Type) -> bool883 fn ty_expect_switch(errors: &Errors, ty: &syn::Type) -> bool {
884 fn ty_can_be_switch(ty: &syn::Type) -> bool {
885 if let syn::Type::Path(path) = ty {
886 if path.qself.is_some() {
887 return false;
888 }
889 if path.path.segments.len() != 1 {
890 return false;
891 }
892 let ident = &path.path.segments[0].ident;
893 // `Option<bool>` can be used as a `switch`.
894 if ident == "Option" {
895 if let PathArguments::AngleBracketed(args) = &path.path.segments[0].arguments {
896 if let GenericArgument::Type(Type::Path(p)) = &args.args[0] {
897 if p.path.segments[0].ident == "bool" {
898 return true;
899 }
900 }
901 }
902 }
903 ["bool", "u8", "u16", "u32", "u64", "u128", "i8", "i16", "i32", "i64", "i128"]
904 .iter()
905 .any(|path| ident == path)
906 } else {
907 false
908 }
909 }
910
911 let res = ty_can_be_switch(ty);
912 if !res {
913 errors.err(ty, "switches must be of type `bool`, `Option<bool>`, or integer type");
914 }
915 res
916 }
917
918 /// Returns `Some(T)` if a type is `wrapper_name<T>` for any `wrapper_name` in `wrapper_names`.
ty_inner<'a>(wrapper_names: &[&str], ty: &'a syn::Type) -> Option<&'a syn::Type>919 fn ty_inner<'a>(wrapper_names: &[&str], ty: &'a syn::Type) -> Option<&'a syn::Type> {
920 if let syn::Type::Path(path) = ty {
921 if path.qself.is_some() {
922 return None;
923 }
924 // Since we only check the last path segment, it isn't necessarily the case that
925 // we're referring to `std::vec::Vec` or `std::option::Option`, but there isn't
926 // a fool proof way to check these since name resolution happens after macro expansion,
927 // so this is likely "good enough" (so long as people don't have their own types called
928 // `Option` or `Vec` that take one generic parameter they're looking to parse).
929 let last_segment = path.path.segments.last()?;
930 if !wrapper_names.iter().any(|name| last_segment.ident == *name) {
931 return None;
932 }
933 if let syn::PathArguments::AngleBracketed(gen_args) = &last_segment.arguments {
934 let generic_arg = gen_args.args.first()?;
935 if let syn::GenericArgument::Type(ty) = &generic_arg {
936 return Some(ty);
937 }
938 }
939 }
940 None
941 }
942
943 /// Implements `FromArgs` and `SubCommands` for a `#![derive(FromArgs)]` enum.
impl_from_args_enum( errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs, generic_args: &syn::Generics, de: &syn::DataEnum, ) -> TokenStream944 fn impl_from_args_enum(
945 errors: &Errors,
946 name: &syn::Ident,
947 type_attrs: &TypeAttrs,
948 generic_args: &syn::Generics,
949 de: &syn::DataEnum,
950 ) -> TokenStream {
951 parse_attrs::check_enum_type_attrs(errors, type_attrs, &de.enum_token.span);
952
953 // An enum variant like `<name>(<ty>)`
954 struct SubCommandVariant<'a> {
955 name: &'a syn::Ident,
956 ty: &'a syn::Type,
957 }
958
959 let mut dynamic_type_and_variant = None;
960
961 let variants: Vec<SubCommandVariant<'_>> = de
962 .variants
963 .iter()
964 .filter_map(|variant| {
965 let name = &variant.ident;
966 let ty = enum_only_single_field_unnamed_variants(errors, &variant.fields)?;
967 if parse_attrs::VariantAttrs::parse(errors, variant).is_dynamic.is_some() {
968 if dynamic_type_and_variant.is_some() {
969 errors.err(variant, "Only one variant can have the `dynamic` attribute");
970 }
971 dynamic_type_and_variant = Some((ty, name));
972 None
973 } else {
974 Some(SubCommandVariant { name, ty })
975 }
976 })
977 .collect();
978
979 let name_repeating = std::iter::repeat(name.clone());
980 let variant_ty = variants.iter().map(|x| x.ty).collect::<Vec<_>>();
981 let variant_names = variants.iter().map(|x| x.name).collect::<Vec<_>>();
982 let dynamic_from_args =
983 dynamic_type_and_variant.as_ref().map(|(dynamic_type, dynamic_variant)| {
984 quote! {
985 if let Some(result) = <#dynamic_type as argh::DynamicSubCommand>::try_from_args(
986 command_name, args) {
987 return result.map(#name::#dynamic_variant);
988 }
989 }
990 });
991 let dynamic_redact_arg_values = dynamic_type_and_variant.as_ref().map(|(dynamic_type, _)| {
992 quote! {
993 if let Some(result) = <#dynamic_type as argh::DynamicSubCommand>::try_redact_arg_values(
994 command_name, args) {
995 return result;
996 }
997 }
998 });
999 let dynamic_commands = dynamic_type_and_variant.as_ref().map(|(dynamic_type, _)| {
1000 quote! {
1001 fn dynamic_commands() -> &'static [&'static argh::CommandInfo] {
1002 <#dynamic_type as argh::DynamicSubCommand>::commands()
1003 }
1004 }
1005 });
1006
1007 let (impl_generics, ty_generics, where_clause) = generic_args.split_for_impl();
1008 quote! {
1009 impl #impl_generics argh::FromArgs for #name #ty_generics #where_clause {
1010 fn from_args(command_name: &[&str], args: &[&str])
1011 -> std::result::Result<Self, argh::EarlyExit>
1012 {
1013 let subcommand_name = if let Some(subcommand_name) = command_name.last() {
1014 *subcommand_name
1015 } else {
1016 return Err(argh::EarlyExit::from("no subcommand name".to_owned()));
1017 };
1018
1019 #(
1020 if subcommand_name == <#variant_ty as argh::SubCommand>::COMMAND.name {
1021 return Ok(#name_repeating::#variant_names(
1022 <#variant_ty as argh::FromArgs>::from_args(command_name, args)?
1023 ));
1024 }
1025 )*
1026
1027 #dynamic_from_args
1028
1029 Err(argh::EarlyExit::from("no subcommand matched".to_owned()))
1030 }
1031
1032 fn redact_arg_values(command_name: &[&str], args: &[&str]) -> std::result::Result<Vec<String>, argh::EarlyExit> {
1033 let subcommand_name = if let Some(subcommand_name) = command_name.last() {
1034 *subcommand_name
1035 } else {
1036 return Err(argh::EarlyExit::from("no subcommand name".to_owned()));
1037 };
1038
1039 #(
1040 if subcommand_name == <#variant_ty as argh::SubCommand>::COMMAND.name {
1041 return <#variant_ty as argh::FromArgs>::redact_arg_values(command_name, args);
1042 }
1043 )*
1044
1045 #dynamic_redact_arg_values
1046
1047 Err(argh::EarlyExit::from("no subcommand matched".to_owned()))
1048 }
1049 }
1050
1051 impl #impl_generics argh::SubCommands for #name #ty_generics #where_clause {
1052 const COMMANDS: &'static [&'static argh::CommandInfo] = &[#(
1053 <#variant_ty as argh::SubCommand>::COMMAND,
1054 )*];
1055
1056 #dynamic_commands
1057 }
1058 }
1059 }
1060
1061 /// Returns `Some(Bar)` if the field is a single-field unnamed variant like `Foo(Bar)`.
1062 /// Otherwise, generates an error.
enum_only_single_field_unnamed_variants<'a>( errors: &Errors, variant_fields: &'a syn::Fields, ) -> Option<&'a syn::Type>1063 fn enum_only_single_field_unnamed_variants<'a>(
1064 errors: &Errors,
1065 variant_fields: &'a syn::Fields,
1066 ) -> Option<&'a syn::Type> {
1067 macro_rules! with_enum_suggestion {
1068 ($help_text:literal) => {
1069 concat!(
1070 $help_text,
1071 "\nInstead, use a variant with a single unnamed field for each subcommand:\n",
1072 " enum MyCommandEnum {\n",
1073 " SubCommandOne(SubCommandOne),\n",
1074 " SubCommandTwo(SubCommandTwo),\n",
1075 " }",
1076 )
1077 };
1078 }
1079
1080 match variant_fields {
1081 syn::Fields::Named(fields) => {
1082 errors.err(
1083 fields,
1084 with_enum_suggestion!(
1085 "`#![derive(FromArgs)]` `enum`s do not support variants with named fields."
1086 ),
1087 );
1088 None
1089 }
1090 syn::Fields::Unit => {
1091 errors.err(
1092 variant_fields,
1093 with_enum_suggestion!(
1094 "`#![derive(FromArgs)]` does not support `enum`s with no variants."
1095 ),
1096 );
1097 None
1098 }
1099 syn::Fields::Unnamed(fields) => {
1100 if fields.unnamed.len() != 1 {
1101 errors.err(
1102 fields,
1103 with_enum_suggestion!(
1104 "`#![derive(FromArgs)]` `enum` variants must only contain one field."
1105 ),
1106 );
1107 None
1108 } else {
1109 // `unwrap` is okay because of the length check above.
1110 let first_field = fields.unnamed.first().unwrap();
1111 Some(&first_field.ty)
1112 }
1113 }
1114 }
1115 }
1116