// xref: /aosp_15_r20/external/bazelbuild-rules_rust/util/process_wrapper/flags.rs (revision d4726bddaa87cc4778e7472feed243fa4b6c267f)
// Copyright 2020 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
15 use std::collections::{BTreeMap, HashSet};
16 use std::error::Error;
17 use std::fmt;
18 use std::fmt::Write;
19 use std::iter::Peekable;
20 use std::mem::take;
21 
/// Errors produced while parsing a command line against a set of defined flags.
#[derive(Debug, Clone)]
pub(crate) enum FlagParseError {
    /// The named argument does not correspond to any defined flag.
    UnknownFlag(String),
    /// The named flag was not followed by any value.
    ValueMissing(String),
    /// The named single-value flag received more than one value or occurrence.
    ProvidedMultipleTimes(String),
    /// The argument vector was empty, so not even the program name was present.
    ProgramNameMissing,
}

impl fmt::Display for FlagParseError {
    /// Renders a short human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ProgramNameMissing => f.write_str("program name (argv[0]) missing"),
            Self::UnknownFlag(name) => write!(f, "unknown flag \"{name}\""),
            Self::ValueMissing(name) => write!(f, "flag \"{name}\" missing parameter(s)"),
            Self::ProvidedMultipleTimes(name) => {
                write!(f, "flag \"{name}\" can only appear once")
            }
        }
    }
}

impl Error for FlagParseError {}
45 
/// A defined flag: its name, its help text, and the slot that receives its parsed value.
struct FlagDef<'a, T> {
    /// Name of the flag; shown first in the help line.
    name: String,
    /// Free-form description shown after the name in the help line.
    help: String,
    /// Caller-owned destination for the parsed value.
    output_storage: &'a mut Option<T>,
}

impl<'a, T> fmt::Display for FlagDef<'a, T> {
    /// Formats the flag as `<name>\t<help>`, the layout used for help output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { name, help, .. } = self;
        write!(f, "{name}\t{help}")
    }
}

impl<'a, T> fmt::Debug for FlagDef<'a, T> {
    /// Shows only `name` and `help`; `output_storage` is left out, so no `T: Debug`
    /// bound is required.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("FlagDef");
        builder.field("name", &self.name);
        builder.field("help", &self.help);
        builder.finish()
    }
}
66 
67 #[derive(Debug)]
68 pub(crate) struct Flags<'a> {
69     single: BTreeMap<String, FlagDef<'a, String>>,
70     repeated: BTreeMap<String, FlagDef<'a, Vec<String>>>,
71 }
72 
73 #[derive(Debug)]
74 pub(crate) enum ParseOutcome {
75     Help(String),
76     Parsed(Vec<String>),
77 }
78 
79 impl<'a> Flags<'a> {
new() -> Flags<'a>80     pub(crate) fn new() -> Flags<'a> {
81         Flags {
82             single: BTreeMap::new(),
83             repeated: BTreeMap::new(),
84         }
85     }
86 
define_flag( &mut self, name: impl Into<String>, help: impl Into<String>, output_storage: &'a mut Option<String>, )87     pub(crate) fn define_flag(
88         &mut self,
89         name: impl Into<String>,
90         help: impl Into<String>,
91         output_storage: &'a mut Option<String>,
92     ) {
93         let name = name.into();
94         if self.repeated.contains_key(&name) {
95             panic!("argument \"{}\" already defined as repeated flag", name)
96         }
97         self.single.insert(
98             name.clone(),
99             FlagDef::<'a, String> {
100                 name,
101                 help: help.into(),
102                 output_storage,
103             },
104         );
105     }
106 
define_repeated_flag( &mut self, name: impl Into<String>, help: impl Into<String>, output_storage: &'a mut Option<Vec<String>>, )107     pub(crate) fn define_repeated_flag(
108         &mut self,
109         name: impl Into<String>,
110         help: impl Into<String>,
111         output_storage: &'a mut Option<Vec<String>>,
112     ) {
113         let name = name.into();
114         if self.single.contains_key(&name) {
115             panic!("argument \"{}\" already defined as flag", name)
116         }
117         self.repeated.insert(
118             name.clone(),
119             FlagDef::<'a, Vec<String>> {
120                 name,
121                 help: help.into(),
122                 output_storage,
123             },
124         );
125     }
126 
help(&self, program_name: String) -> String127     fn help(&self, program_name: String) -> String {
128         let single = self.single.values().map(|fd| fd.to_string());
129         let repeated = self.repeated.values().map(|fd| fd.to_string());
130         let mut all: Vec<String> = single.chain(repeated).collect();
131         all.sort();
132 
133         let mut help_text = String::new();
134         writeln!(
135             &mut help_text,
136             "Help for {program_name}: [options] -- [extra arguments]"
137         )
138         .unwrap();
139         for line in all {
140             writeln!(&mut help_text, "\t{line}").unwrap();
141         }
142         help_text
143     }
144 
parse(mut self, argv: Vec<String>) -> Result<ParseOutcome, FlagParseError>145     pub(crate) fn parse(mut self, argv: Vec<String>) -> Result<ParseOutcome, FlagParseError> {
146         let mut argv_iter = argv.into_iter().peekable();
147         let program_name = argv_iter.next().ok_or(FlagParseError::ProgramNameMissing)?;
148 
149         // To check if a non-repeated flag has been set already.
150         let mut seen_single_flags = HashSet::<String>::new();
151 
152         while let Some(flag) = argv_iter.next() {
153             if flag == "--help" {
154                 return Ok(ParseOutcome::Help(self.help(program_name)));
155             }
156             if !flag.starts_with("--") {
157                 return Err(FlagParseError::UnknownFlag(flag));
158             }
159             let mut args = consume_args(&flag, &mut argv_iter);
160             if flag == "--" {
161                 return Ok(ParseOutcome::Parsed(args));
162             }
163             if args.is_empty() {
164                 return Err(FlagParseError::ValueMissing(flag.clone()));
165             }
166             if let Some(flag_def) = self.single.get_mut(&flag) {
167                 if args.len() > 1 || seen_single_flags.contains(&flag) {
168                     return Err(FlagParseError::ProvidedMultipleTimes(flag.clone()));
169                 }
170                 let arg = args.first_mut().unwrap();
171                 seen_single_flags.insert(flag);
172                 *flag_def.output_storage = Some(take(arg));
173                 continue;
174             }
175             if let Some(flag_def) = self.repeated.get_mut(&flag) {
176                 flag_def
177                     .output_storage
178                     .get_or_insert_with(Vec::new)
179                     .append(&mut args);
180                 continue;
181             }
182             return Err(FlagParseError::UnknownFlag(flag));
183         }
184         Ok(ParseOutcome::Parsed(vec![]))
185     }
186 }
187 
/// Pulls the values belonging to `flag` off the front of `argv_iter`.
///
/// For the `--` separator, every remaining argument is taken verbatim. For any
/// other flag, arguments are taken until the next one that starts with `--`
/// (which is left in the iterator) or until the iterator is exhausted.
fn consume_args<I: Iterator<Item = String>>(
    flag: &str,
    argv_iter: &mut Peekable<I>,
) -> Vec<String> {
    if flag == "--" {
        // If we have found --, the rest of the iterator is just returned as-is.
        return argv_iter.collect();
    }
    std::iter::from_fn(|| argv_iter.next_if(|candidate| !candidate.starts_with("--"))).collect()
}
203 
#[cfg(test)]
mod test {
    use super::*;

    /// Builds an argument vector whose program name is `foo`, followed by `args`.
    fn args(args: &[&str]) -> Vec<String> {
        std::iter::once("foo")
            .chain(args.iter().copied())
            .map(str::to_owned)
            .collect()
    }

    /// `--help` must short-circuit parsing and list the defined flags.
    #[test]
    fn test_flag_help() {
        let mut bar = None;
        let mut parser = Flags::new();
        parser.define_flag("--bar", "bar help", &mut bar);
        match parser.parse(args(&["--help"])).unwrap() {
            ParseOutcome::Help(text) => {
                assert!(text.contains("Help for foo"));
                assert!(text.contains("--bar\tbar help"));
            }
            ParseOutcome::Parsed(_) => {
                panic!("expected that --help would invoke help, instead parsed arguments")
            }
        }
    }

    /// A single-value flag rejects both several values and several occurrences.
    #[test]
    fn test_flag_single_repeated() {
        let mut bar = None;
        let command_lines = [
            args(&["--bar", "aa", "bb"]),
            args(&["--bar", "aa", "--bar", "bb"]),
        ];
        for argv in command_lines {
            let mut parser = Flags::new();
            parser.define_flag("--bar", "bar help", &mut bar);
            let result = parser.parse(argv);
            match &result {
                Err(FlagParseError::ProvidedMultipleTimes(f)) => assert_eq!(f, "--bar"),
                _ => panic!("expected error, got {:?}", result),
            }
        }
    }

    /// A repeated flag collects its values in order, however they are spelled.
    #[test]
    fn test_repeated_flags() {
        let command_lines = [
            // Test case 1) --bar something something_else should work as a repeated flag.
            args(&["--bar", "aa", "bb"]),
            // Test case 2) --bar something --bar something_else should also work as a
            // repeated flag.
            args(&["--bar", "aa", "--bar", "bb"]),
        ];
        for argv in command_lines {
            let mut bar = None;
            let mut parser = Flags::new();
            parser.define_repeated_flag("--bar", "bar help", &mut bar);
            let result = parser.parse(argv).unwrap();
            assert!(matches!(result, ParseOutcome::Parsed(_)));
            assert_eq!(bar, Some(vec!["aa".to_owned(), "bb".to_owned()]));
        }
    }

    /// Everything after `--` is handed back untouched.
    #[test]
    fn test_extra_args() {
        let parser = Flags::new();
        let result = parser.parse(args(&["--", "bb"])).unwrap();
        match &result {
            ParseOutcome::Parsed(got) => assert_eq!(got, &vec!["bb".to_owned()]),
            _ => panic!("expected correct parsing, got {:?}", result),
        }
    }
}
276