1 // Internal
2 use crate::builder::StyledStr;
3 use crate::builder::{Arg, ArgGroup, ArgPredicate, Command, PossibleValue};
4 use crate::error::{Error, Result as ClapResult};
5 use crate::output::Usage;
6 use crate::parser::{ArgMatcher, ParseState};
7 use crate::util::ChildGraph;
8 use crate::util::FlatMap;
9 use crate::util::FlatSet;
10 use crate::util::Id;
11 use crate::INTERNAL_ERROR_MSG;
12 
/// Post-parse validator for the matches of a single `Command`.
pub(crate) struct Validator<'cmd> {
    /// The command whose argument rules are being enforced.
    cmd: &'cmd Command,
    /// Required args/groups: seeded from `Command::required_graph` and extended with the
    /// `requires` relationships of explicitly-present args/groups during validation.
    required: ChildGraph<Id>,
}
17 
18 impl<'cmd> Validator<'cmd> {
new(cmd: &'cmd Command) -> Self19     pub(crate) fn new(cmd: &'cmd Command) -> Self {
20         let required = cmd.required_graph();
21         Validator { cmd, required }
22     }
23 
validate( &mut self, parse_state: ParseState, matcher: &mut ArgMatcher, ) -> ClapResult<()>24     pub(crate) fn validate(
25         &mut self,
26         parse_state: ParseState,
27         matcher: &mut ArgMatcher,
28     ) -> ClapResult<()> {
29         debug!("Validator::validate");
30         let conflicts = Conflicts::with_args(self.cmd, matcher);
31         let has_subcmd = matcher.subcommand_name().is_some();
32 
33         if let ParseState::Opt(a) = parse_state {
34             debug!("Validator::validate: needs_val_of={a:?}");
35 
36             let o = &self.cmd[&a];
37             let should_err = if let Some(v) = matcher.args.get(o.get_id()) {
38                 v.all_val_groups_empty() && o.get_min_vals() != 0
39             } else {
40                 true
41             };
42             if should_err {
43                 return Err(Error::empty_value(
44                     self.cmd,
45                     &get_possible_values_cli(o)
46                         .iter()
47                         .filter(|pv| !pv.is_hide_set())
48                         .map(|n| n.get_name().to_owned())
49                         .collect::<Vec<_>>(),
50                     o.to_string(),
51                 ));
52             }
53         }
54 
55         if !has_subcmd && self.cmd.is_arg_required_else_help_set() {
56             let num_user_values = matcher
57                 .args()
58                 .filter(|(_, matched)| matched.check_explicit(&ArgPredicate::IsPresent))
59                 .count();
60             if num_user_values == 0 {
61                 let message = self.cmd.write_help_err(false);
62                 return Err(Error::display_help_error(self.cmd, message));
63             }
64         }
65         if !has_subcmd && self.cmd.is_subcommand_required_set() {
66             let bn = self.cmd.get_bin_name_fallback();
67             return Err(Error::missing_subcommand(
68                 self.cmd,
69                 bn.to_string(),
70                 self.cmd
71                     .all_subcommand_names()
72                     .map(|s| s.to_owned())
73                     .collect::<Vec<_>>(),
74                 Usage::new(self.cmd)
75                     .required(&self.required)
76                     .create_usage_with_title(&[]),
77             ));
78         }
79 
80         ok!(self.validate_conflicts(matcher, &conflicts));
81         if !(self.cmd.is_subcommand_negates_reqs_set() && has_subcmd) {
82             ok!(self.validate_required(matcher, &conflicts));
83         }
84 
85         Ok(())
86     }
87 
validate_conflicts( &mut self, matcher: &ArgMatcher, conflicts: &Conflicts, ) -> ClapResult<()>88     fn validate_conflicts(
89         &mut self,
90         matcher: &ArgMatcher,
91         conflicts: &Conflicts,
92     ) -> ClapResult<()> {
93         debug!("Validator::validate_conflicts");
94 
95         ok!(self.validate_exclusive(matcher));
96 
97         for (arg_id, _) in matcher
98             .args()
99             .filter(|(_, matched)| matched.check_explicit(&ArgPredicate::IsPresent))
100             .filter(|(arg_id, _)| self.cmd.find(arg_id).is_some())
101         {
102             debug!("Validator::validate_conflicts::iter: id={arg_id:?}");
103             let conflicts = conflicts.gather_conflicts(self.cmd, arg_id);
104             ok!(self.build_conflict_err(arg_id, &conflicts, matcher));
105         }
106 
107         Ok(())
108     }
109 
validate_exclusive(&self, matcher: &ArgMatcher) -> ClapResult<()>110     fn validate_exclusive(&self, matcher: &ArgMatcher) -> ClapResult<()> {
111         debug!("Validator::validate_exclusive");
112         let args_count = matcher
113             .args()
114             .filter(|(arg_id, matched)| {
115                 matched.check_explicit(&crate::builder::ArgPredicate::IsPresent)
116                     // Avoid including our own groups by checking none of them.  If a group is present, the
117                     // args for the group will be.
118                     && self.cmd.find(arg_id).is_some()
119             })
120             .count();
121         if args_count <= 1 {
122             // Nothing present to conflict with
123             return Ok(());
124         }
125 
126         matcher
127             .args()
128             .filter(|(_, matched)| matched.check_explicit(&crate::builder::ArgPredicate::IsPresent))
129             .filter_map(|(id, _)| {
130                 debug!("Validator::validate_exclusive:iter:{id:?}");
131                 self.cmd
132                     .find(id)
133                     // Find `arg`s which are exclusive but also appear with other args.
134                     .filter(|&arg| arg.is_exclusive_set() && args_count > 1)
135             })
136             .next()
137             .map(|arg| {
138                 // Throw an error for the first conflict found.
139                 Err(Error::argument_conflict(
140                     self.cmd,
141                     arg.to_string(),
142                     Vec::new(),
143                     Usage::new(self.cmd)
144                         .required(&self.required)
145                         .create_usage_with_title(&[]),
146                 ))
147             })
148             .unwrap_or(Ok(()))
149     }
150 
build_conflict_err( &self, name: &Id, conflict_ids: &[Id], matcher: &ArgMatcher, ) -> ClapResult<()>151     fn build_conflict_err(
152         &self,
153         name: &Id,
154         conflict_ids: &[Id],
155         matcher: &ArgMatcher,
156     ) -> ClapResult<()> {
157         if conflict_ids.is_empty() {
158             return Ok(());
159         }
160 
161         debug!("Validator::build_conflict_err: name={name:?}");
162         let mut seen = FlatSet::new();
163         let conflicts = conflict_ids
164             .iter()
165             .flat_map(|c_id| {
166                 if self.cmd.find_group(c_id).is_some() {
167                     self.cmd.unroll_args_in_group(c_id)
168                 } else {
169                     vec![c_id.clone()]
170                 }
171             })
172             .filter_map(|c_id| {
173                 seen.insert(c_id.clone()).then(|| {
174                     let c_arg = self.cmd.find(&c_id).expect(INTERNAL_ERROR_MSG);
175                     c_arg.to_string()
176                 })
177             })
178             .collect();
179 
180         let former_arg = self.cmd.find(name).expect(INTERNAL_ERROR_MSG);
181         let usg = self.build_conflict_err_usage(matcher, conflict_ids);
182         Err(Error::argument_conflict(
183             self.cmd,
184             former_arg.to_string(),
185             conflicts,
186             usg,
187         ))
188     }
189 
build_conflict_err_usage( &self, matcher: &ArgMatcher, conflicting_keys: &[Id], ) -> Option<StyledStr>190     fn build_conflict_err_usage(
191         &self,
192         matcher: &ArgMatcher,
193         conflicting_keys: &[Id],
194     ) -> Option<StyledStr> {
195         let used_filtered: Vec<Id> = matcher
196             .args()
197             .filter(|(_, matched)| matched.check_explicit(&ArgPredicate::IsPresent))
198             .map(|(n, _)| n)
199             .filter(|n| {
200                 // Filter out the args we don't want to specify.
201                 self.cmd
202                     .find(n)
203                     .map(|a| !a.is_hide_set())
204                     .unwrap_or_default()
205             })
206             .filter(|key| !conflicting_keys.contains(key))
207             .cloned()
208             .collect();
209         let required: Vec<Id> = used_filtered
210             .iter()
211             .filter_map(|key| self.cmd.find(key))
212             .flat_map(|arg| arg.requires.iter().map(|item| &item.1))
213             .filter(|key| !used_filtered.contains(key) && !conflicting_keys.contains(key))
214             .chain(used_filtered.iter())
215             .cloned()
216             .collect();
217         Usage::new(self.cmd)
218             .required(&self.required)
219             .create_usage_with_title(&required)
220     }
221 
gather_requires(&mut self, matcher: &ArgMatcher)222     fn gather_requires(&mut self, matcher: &ArgMatcher) {
223         debug!("Validator::gather_requires");
224         for (name, matched) in matcher
225             .args()
226             .filter(|(_, matched)| matched.check_explicit(&ArgPredicate::IsPresent))
227         {
228             debug!("Validator::gather_requires:iter:{name:?}");
229             if let Some(arg) = self.cmd.find(name) {
230                 let is_relevant = |(val, req_arg): &(ArgPredicate, Id)| -> Option<Id> {
231                     let required = matched.check_explicit(val);
232                     required.then(|| req_arg.clone())
233                 };
234 
235                 for req in self.cmd.unroll_arg_requires(is_relevant, arg.get_id()) {
236                     self.required.insert(req);
237                 }
238             } else if let Some(g) = self.cmd.find_group(name) {
239                 debug!("Validator::gather_requires:iter:{name:?}:group");
240                 for r in &g.requires {
241                     self.required.insert(r.clone());
242                 }
243             }
244         }
245     }
246 
validate_required(&mut self, matcher: &ArgMatcher, conflicts: &Conflicts) -> ClapResult<()>247     fn validate_required(&mut self, matcher: &ArgMatcher, conflicts: &Conflicts) -> ClapResult<()> {
248         debug!("Validator::validate_required: required={:?}", self.required);
249         self.gather_requires(matcher);
250 
251         let mut missing_required = Vec::new();
252         let mut highest_index = 0;
253 
254         let is_exclusive_present = matcher
255             .args()
256             .filter(|(_, matched)| matched.check_explicit(&ArgPredicate::IsPresent))
257             .any(|(id, _)| {
258                 self.cmd
259                     .find(id)
260                     .map(|arg| arg.is_exclusive_set())
261                     .unwrap_or_default()
262             });
263         debug!("Validator::validate_required: is_exclusive_present={is_exclusive_present}");
264 
265         for arg_or_group in self
266             .required
267             .iter()
268             .filter(|r| !matcher.check_explicit(r, &ArgPredicate::IsPresent))
269         {
270             debug!("Validator::validate_required:iter:aog={arg_or_group:?}");
271             if let Some(arg) = self.cmd.find(arg_or_group) {
272                 debug!("Validator::validate_required:iter: This is an arg");
273                 if !is_exclusive_present && !self.is_missing_required_ok(arg, conflicts) {
274                     debug!(
275                         "Validator::validate_required:iter: Missing {:?}",
276                         arg.get_id()
277                     );
278                     missing_required.push(arg.get_id().clone());
279                     if !arg.is_last_set() {
280                         highest_index = highest_index.max(arg.get_index().unwrap_or(0));
281                     }
282                 }
283             } else if let Some(group) = self.cmd.find_group(arg_or_group) {
284                 debug!("Validator::validate_required:iter: This is a group");
285                 if !self
286                     .cmd
287                     .unroll_args_in_group(&group.id)
288                     .iter()
289                     .any(|a| matcher.check_explicit(a, &ArgPredicate::IsPresent))
290                 {
291                     debug!(
292                         "Validator::validate_required:iter: Missing {:?}",
293                         group.get_id()
294                     );
295                     missing_required.push(group.get_id().clone());
296                 }
297             }
298         }
299 
300         // Validate the conditionally required args
301         for a in self
302             .cmd
303             .get_arguments()
304             .filter(|a| !matcher.check_explicit(a.get_id(), &ArgPredicate::IsPresent))
305         {
306             let mut required = false;
307 
308             for (other, val) in &a.r_ifs {
309                 if matcher.check_explicit(other, &ArgPredicate::Equals(val.into())) {
310                     debug!(
311                         "Validator::validate_required:iter: Missing {:?}",
312                         a.get_id()
313                     );
314                     required = true;
315                 }
316             }
317 
318             let match_all = a.r_ifs_all.iter().all(|(other, val)| {
319                 matcher.check_explicit(other, &ArgPredicate::Equals(val.into()))
320             });
321             if match_all && !a.r_ifs_all.is_empty() {
322                 debug!(
323                     "Validator::validate_required:iter: Missing {:?}",
324                     a.get_id()
325                 );
326                 required = true;
327             }
328 
329             if (!a.r_unless.is_empty() || !a.r_unless_all.is_empty())
330                 && self.fails_arg_required_unless(a, matcher)
331             {
332                 debug!(
333                     "Validator::validate_required:iter: Missing {:?}",
334                     a.get_id()
335                 );
336                 required = true;
337             }
338 
339             if required {
340                 missing_required.push(a.get_id().clone());
341                 if !a.is_last_set() {
342                     highest_index = highest_index.max(a.get_index().unwrap_or(0));
343                 }
344             }
345         }
346 
347         // For display purposes, include all of the preceding positional arguments
348         if !self.cmd.is_allow_missing_positional_set() {
349             for pos in self
350                 .cmd
351                 .get_positionals()
352                 .filter(|a| !matcher.check_explicit(a.get_id(), &ArgPredicate::IsPresent))
353             {
354                 if pos.get_index() < Some(highest_index) {
355                     debug!(
356                         "Validator::validate_required:iter: Missing {:?}",
357                         pos.get_id()
358                     );
359                     missing_required.push(pos.get_id().clone());
360                 }
361             }
362         }
363 
364         if !missing_required.is_empty() {
365             ok!(self.missing_required_error(matcher, missing_required));
366         }
367 
368         Ok(())
369     }
370 
is_missing_required_ok(&self, a: &Arg, conflicts: &Conflicts) -> bool371     fn is_missing_required_ok(&self, a: &Arg, conflicts: &Conflicts) -> bool {
372         debug!("Validator::is_missing_required_ok: {}", a.get_id());
373         if !conflicts.gather_conflicts(self.cmd, a.get_id()).is_empty() {
374             debug!("Validator::is_missing_required_ok: true (self)");
375             return true;
376         }
377         for group_id in self.cmd.groups_for_arg(a.get_id()) {
378             if !conflicts.gather_conflicts(self.cmd, &group_id).is_empty() {
379                 debug!("Validator::is_missing_required_ok: true ({group_id})");
380                 return true;
381             }
382         }
383         false
384     }
385 
386     // Failing a required unless means, the arg's "unless" wasn't present, and neither were they
fails_arg_required_unless(&self, a: &Arg, matcher: &ArgMatcher) -> bool387     fn fails_arg_required_unless(&self, a: &Arg, matcher: &ArgMatcher) -> bool {
388         debug!("Validator::fails_arg_required_unless: a={:?}", a.get_id());
389         let exists = |id| matcher.check_explicit(id, &ArgPredicate::IsPresent);
390 
391         (a.r_unless_all.is_empty() || !a.r_unless_all.iter().all(exists))
392             && !a.r_unless.iter().any(exists)
393     }
394 
395     // `req_args`: an arg to include in the error even if not used
missing_required_error( &self, matcher: &ArgMatcher, raw_req_args: Vec<Id>, ) -> ClapResult<()>396     fn missing_required_error(
397         &self,
398         matcher: &ArgMatcher,
399         raw_req_args: Vec<Id>,
400     ) -> ClapResult<()> {
401         debug!("Validator::missing_required_error; incl={raw_req_args:?}");
402         debug!(
403             "Validator::missing_required_error: reqs={:?}",
404             self.required
405         );
406 
407         let usg = Usage::new(self.cmd).required(&self.required);
408 
409         let req_args = {
410             #[cfg(feature = "usage")]
411             {
412                 usg.get_required_usage_from(&raw_req_args, Some(matcher), true)
413                     .into_iter()
414                     .map(|s| s.to_string())
415                     .collect::<Vec<_>>()
416             }
417 
418             #[cfg(not(feature = "usage"))]
419             {
420                 raw_req_args
421                     .iter()
422                     .map(|id| {
423                         if let Some(arg) = self.cmd.find(id) {
424                             arg.to_string()
425                         } else if let Some(_group) = self.cmd.find_group(id) {
426                             self.cmd.format_group(id).to_string()
427                         } else {
428                             debug_assert!(false, "id={id:?} is unknown");
429                             "".to_owned()
430                         }
431                     })
432                     .collect::<FlatSet<_>>()
433                     .into_iter()
434                     .collect::<Vec<_>>()
435             }
436         };
437 
438         debug!("Validator::missing_required_error: req_args={req_args:#?}");
439 
440         let used: Vec<Id> = matcher
441             .args()
442             .filter(|(_, matched)| matched.check_explicit(&ArgPredicate::IsPresent))
443             .map(|(n, _)| n)
444             .filter(|n| {
445                 // Filter out the args we don't want to specify.
446                 self.cmd
447                     .find(n)
448                     .map(|a| !a.is_hide_set())
449                     .unwrap_or_default()
450             })
451             .cloned()
452             .chain(raw_req_args)
453             .collect();
454 
455         Err(Error::missing_required_argument(
456             self.cmd,
457             req_args,
458             usg.create_usage_with_title(&used),
459         ))
460     }
461 }
462 
/// Direct conflicts of every explicitly-present arg or group, computed once per validation.
#[derive(Default, Clone, Debug)]
struct Conflicts {
    /// Maps each explicitly-present arg/group id to the ids it directly conflicts with.
    potential: FlatMap<Id, Vec<Id>>,
}
467 
468 impl Conflicts {
with_args(cmd: &Command, matcher: &ArgMatcher) -> Self469     fn with_args(cmd: &Command, matcher: &ArgMatcher) -> Self {
470         let mut potential = FlatMap::new();
471         potential.extend_unchecked(
472             matcher
473                 .args()
474                 .filter(|(_, matched)| matched.check_explicit(&ArgPredicate::IsPresent))
475                 .map(|(id, _)| {
476                     let conf = gather_direct_conflicts(cmd, id);
477                     (id.clone(), conf)
478                 }),
479         );
480         Self { potential }
481     }
482 
gather_conflicts(&self, cmd: &Command, arg_id: &Id) -> Vec<Id>483     fn gather_conflicts(&self, cmd: &Command, arg_id: &Id) -> Vec<Id> {
484         debug!("Conflicts::gather_conflicts: arg={arg_id:?}");
485         let mut conflicts = Vec::new();
486 
487         let arg_id_conflicts_storage;
488         let arg_id_conflicts = if let Some(arg_id_conflicts) = self.get_direct_conflicts(arg_id) {
489             arg_id_conflicts
490         } else {
491             // `is_missing_required_ok` is a case where we check not-present args for conflicts
492             arg_id_conflicts_storage = gather_direct_conflicts(cmd, arg_id);
493             &arg_id_conflicts_storage
494         };
495         for (other_arg_id, other_arg_id_conflicts) in self.potential.iter() {
496             if arg_id == other_arg_id {
497                 continue;
498             }
499 
500             if arg_id_conflicts.contains(other_arg_id) {
501                 conflicts.push(other_arg_id.clone());
502             }
503             if other_arg_id_conflicts.contains(arg_id) {
504                 conflicts.push(other_arg_id.clone());
505             }
506         }
507 
508         debug!("Conflicts::gather_conflicts: conflicts={conflicts:?}");
509         conflicts
510     }
511 
get_direct_conflicts(&self, arg_id: &Id) -> Option<&[Id]>512     fn get_direct_conflicts(&self, arg_id: &Id) -> Option<&[Id]> {
513         self.potential.get(arg_id).map(Vec::as_slice)
514     }
515 }
516 
gather_direct_conflicts(cmd: &Command, id: &Id) -> Vec<Id>517 fn gather_direct_conflicts(cmd: &Command, id: &Id) -> Vec<Id> {
518     let conf = if let Some(arg) = cmd.find(id) {
519         gather_arg_direct_conflicts(cmd, arg)
520     } else if let Some(group) = cmd.find_group(id) {
521         gather_group_direct_conflicts(group)
522     } else {
523         debug_assert!(false, "id={id:?} is unknown");
524         Vec::new()
525     };
526     debug!("Conflicts::gather_direct_conflicts id={id:?}, conflicts={conf:?}",);
527     conf
528 }
529 
gather_arg_direct_conflicts(cmd: &Command, arg: &Arg) -> Vec<Id>530 fn gather_arg_direct_conflicts(cmd: &Command, arg: &Arg) -> Vec<Id> {
531     let mut conf = arg.blacklist.clone();
532     for group_id in cmd.groups_for_arg(arg.get_id()) {
533         let group = cmd.find_group(&group_id).expect(INTERNAL_ERROR_MSG);
534         conf.extend(group.conflicts.iter().cloned());
535         if !group.multiple {
536             for member_id in &group.args {
537                 if member_id != arg.get_id() {
538                     conf.push(member_id.clone());
539                 }
540             }
541         }
542     }
543 
544     // Overrides are implicitly conflicts
545     conf.extend(arg.overrides.iter().cloned());
546 
547     conf
548 }
549 
gather_group_direct_conflicts(group: &ArgGroup) -> Vec<Id>550 fn gather_group_direct_conflicts(group: &ArgGroup) -> Vec<Id> {
551     group.conflicts.clone()
552 }
553 
get_possible_values_cli(a: &Arg) -> Vec<PossibleValue>554 pub(crate) fn get_possible_values_cli(a: &Arg) -> Vec<PossibleValue> {
555     if !a.is_takes_value_set() {
556         vec![]
557     } else {
558         a.get_value_parser()
559             .possible_values()
560             .map(|pvs| pvs.collect())
561             .unwrap_or_default()
562     }
563 }
564