1 // Copyright 2020 The Bazel Authors. All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 use std::collections::{BTreeMap, HashSet};
16 use std::error::Error;
17 use std::fmt;
18 use std::fmt::Write;
19 use std::iter::Peekable;
20 use std::mem::take;
21
/// Errors produced while parsing command-line flags with `Flags::parse`.
#[derive(Clone, Debug)]
pub(crate) enum FlagParseError {
    /// An argument did not start with `--`, or named a flag that was never defined.
    UnknownFlag(String),
    /// A flag was not followed by any value.
    ValueMissing(String),
    /// A non-repeated flag occurred twice, or received more than one value.
    ProvidedMultipleTimes(String),
    /// `argv` was empty, so there was no program name to consume.
    ProgramNameMissing,
}

impl fmt::Display for FlagParseError {
    /// Renders a short, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            Self::UnknownFlag(ref flag) => write!(f, "unknown flag \"{flag}\""),
            Self::ValueMissing(ref flag) => write!(f, "flag \"{flag}\" missing parameter(s)"),
            Self::ProvidedMultipleTimes(ref flag) => {
                write!(f, "flag \"{flag}\" can only appear once")
            }
            Self::ProgramNameMissing => f.write_str("program name (argv[0]) missing"),
        }
    }
}

// No variant wraps an underlying cause, so the default `source()` (None) is kept.
impl Error for FlagParseError {}
45
/// One registered flag: its name, its help text, and the caller's slot for parsed values.
struct FlagDef<'a, T> {
    /// Name compared verbatim against command-line arguments (e.g. `--bar`).
    name: String,
    /// Description shown next to the name in the `--help` output.
    help: String,
    /// Caller-owned slot that the parser writes the parsed value(s) into.
    output_storage: &'a mut Option<T>,
}

impl<T> fmt::Display for FlagDef<'_, T> {
    /// Renders the `name<TAB>help` line used in the help text.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let FlagDef { name, help, .. } = self;
        write!(f, "{}\t{}", name, help)
    }
}

impl<T> fmt::Debug for FlagDef<'_, T> {
    // Implemented by hand without a `T: Debug` bound; `output_storage` is not shown.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("FlagDef");
        out.field("name", &self.name);
        out.field("help", &self.help);
        out.finish()
    }
}
66
67 #[derive(Debug)]
68 pub(crate) struct Flags<'a> {
69 single: BTreeMap<String, FlagDef<'a, String>>,
70 repeated: BTreeMap<String, FlagDef<'a, Vec<String>>>,
71 }
72
/// What a successful call to `Flags::parse` produced.
#[derive(Debug)]
pub(crate) enum ParseOutcome {
    /// `--help` was requested; carries the rendered help text.
    Help(String),
    /// Flags were parsed; carries the arguments after `--` (empty when there was none).
    Parsed(Vec<String>),
}
78
79 impl<'a> Flags<'a> {
new() -> Flags<'a>80 pub(crate) fn new() -> Flags<'a> {
81 Flags {
82 single: BTreeMap::new(),
83 repeated: BTreeMap::new(),
84 }
85 }
86
define_flag( &mut self, name: impl Into<String>, help: impl Into<String>, output_storage: &'a mut Option<String>, )87 pub(crate) fn define_flag(
88 &mut self,
89 name: impl Into<String>,
90 help: impl Into<String>,
91 output_storage: &'a mut Option<String>,
92 ) {
93 let name = name.into();
94 if self.repeated.contains_key(&name) {
95 panic!("argument \"{}\" already defined as repeated flag", name)
96 }
97 self.single.insert(
98 name.clone(),
99 FlagDef::<'a, String> {
100 name,
101 help: help.into(),
102 output_storage,
103 },
104 );
105 }
106
define_repeated_flag( &mut self, name: impl Into<String>, help: impl Into<String>, output_storage: &'a mut Option<Vec<String>>, )107 pub(crate) fn define_repeated_flag(
108 &mut self,
109 name: impl Into<String>,
110 help: impl Into<String>,
111 output_storage: &'a mut Option<Vec<String>>,
112 ) {
113 let name = name.into();
114 if self.single.contains_key(&name) {
115 panic!("argument \"{}\" already defined as flag", name)
116 }
117 self.repeated.insert(
118 name.clone(),
119 FlagDef::<'a, Vec<String>> {
120 name,
121 help: help.into(),
122 output_storage,
123 },
124 );
125 }
126
help(&self, program_name: String) -> String127 fn help(&self, program_name: String) -> String {
128 let single = self.single.values().map(|fd| fd.to_string());
129 let repeated = self.repeated.values().map(|fd| fd.to_string());
130 let mut all: Vec<String> = single.chain(repeated).collect();
131 all.sort();
132
133 let mut help_text = String::new();
134 writeln!(
135 &mut help_text,
136 "Help for {program_name}: [options] -- [extra arguments]"
137 )
138 .unwrap();
139 for line in all {
140 writeln!(&mut help_text, "\t{line}").unwrap();
141 }
142 help_text
143 }
144
parse(mut self, argv: Vec<String>) -> Result<ParseOutcome, FlagParseError>145 pub(crate) fn parse(mut self, argv: Vec<String>) -> Result<ParseOutcome, FlagParseError> {
146 let mut argv_iter = argv.into_iter().peekable();
147 let program_name = argv_iter.next().ok_or(FlagParseError::ProgramNameMissing)?;
148
149 // To check if a non-repeated flag has been set already.
150 let mut seen_single_flags = HashSet::<String>::new();
151
152 while let Some(flag) = argv_iter.next() {
153 if flag == "--help" {
154 return Ok(ParseOutcome::Help(self.help(program_name)));
155 }
156 if !flag.starts_with("--") {
157 return Err(FlagParseError::UnknownFlag(flag));
158 }
159 let mut args = consume_args(&flag, &mut argv_iter);
160 if flag == "--" {
161 return Ok(ParseOutcome::Parsed(args));
162 }
163 if args.is_empty() {
164 return Err(FlagParseError::ValueMissing(flag.clone()));
165 }
166 if let Some(flag_def) = self.single.get_mut(&flag) {
167 if args.len() > 1 || seen_single_flags.contains(&flag) {
168 return Err(FlagParseError::ProvidedMultipleTimes(flag.clone()));
169 }
170 let arg = args.first_mut().unwrap();
171 seen_single_flags.insert(flag);
172 *flag_def.output_storage = Some(take(arg));
173 continue;
174 }
175 if let Some(flag_def) = self.repeated.get_mut(&flag) {
176 flag_def
177 .output_storage
178 .get_or_insert_with(Vec::new)
179 .append(&mut args);
180 continue;
181 }
182 return Err(FlagParseError::UnknownFlag(flag));
183 }
184 Ok(ParseOutcome::Parsed(vec![]))
185 }
186 }
187
/// Pulls the values that belong to `flag` off the front of `argv_iter`.
///
/// For the `--` separator, every remaining argument is taken verbatim. For any other
/// flag, arguments are taken up to, but not including, the next one that starts with
/// `--`; that argument stays in the iterator for the caller.
fn consume_args<I: Iterator<Item = String>>(
    flag: &str,
    argv_iter: &mut Peekable<I>,
) -> Vec<String> {
    if flag != "--" {
        // `next_if` only advances when the peeked argument is a value rather than a flag.
        return std::iter::from_fn(|| argv_iter.next_if(|arg| !arg.starts_with("--")))
            .collect();
    }
    argv_iter.collect()
}
203
#[cfg(test)]
mod test {
    use super::*;

    /// Builds an argv vector whose program name (argv[0]) is `foo`, followed by `args`.
    fn args(args: &[&str]) -> Vec<String> {
        ["foo"].iter().chain(args).map(|&s| s.to_owned()).collect()
    }

    #[test]
    fn test_flag_help() {
        let mut bar = None;
        let mut parser = Flags::new();
        parser.define_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--help"])).unwrap();
        if let ParseOutcome::Help(h) = result {
            assert!(h.contains("Help for foo"));
            assert!(h.contains("--bar\tbar help"));
        } else {
            panic!("expected that --help would invoke help, instead parsed arguments")
        }
    }

    #[test]
    fn test_flag_help_lists_all_flags_sorted() {
        // Single and repeated flags are listed together, in sorted order.
        let mut zed = None;
        let mut abc = None;
        let mut parser = Flags::new();
        parser.define_flag("--zed", "zed help", &mut zed);
        parser.define_repeated_flag("--abc", "abc help", &mut abc);
        match parser.parse(args(&["--help"])) {
            Ok(ParseOutcome::Help(h)) => assert_eq!(
                h,
                "Help for foo: [options] -- [extra arguments]\n\t--abc\tabc help\n\t--zed\tzed help\n"
            ),
            other => panic!("expected help text, got {:?}", other),
        }
    }

    #[test]
    fn test_flag_single_value() {
        let mut bar = None;
        let mut parser = Flags::new();
        parser.define_flag("--bar", "bar help", &mut bar);
        match parser.parse(args(&["--bar", "aa", "--", "x", "y"])) {
            Ok(ParseOutcome::Parsed(got)) => {
                assert_eq!(got, vec!["x".to_owned(), "y".to_owned()])
            }
            other => panic!("expected correct parsing, got {:?}", other),
        }
        assert_eq!(bar, Some("aa".to_owned()));
    }

    #[test]
    fn test_flag_single_repeated() {
        let mut bar = None;
        let mut parser = Flags::new();
        parser.define_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--bar", "aa", "bb"]));
        if let Err(FlagParseError::ProvidedMultipleTimes(f)) = result {
            assert_eq!(f, "--bar");
        } else {
            panic!("expected error, got {:?}", result)
        }
        let mut parser = Flags::new();
        parser.define_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--bar", "aa", "--bar", "bb"]));
        if let Err(FlagParseError::ProvidedMultipleTimes(f)) = result {
            assert_eq!(f, "--bar");
        } else {
            panic!("expected error, got {:?}", result)
        }
    }

    #[test]
    fn test_flag_missing_value() {
        // The `--` separator is not a value, so `--bar` ends up with none.
        let mut bar = None;
        let mut parser = Flags::new();
        parser.define_flag("--bar", "bar help", &mut bar);
        match parser.parse(args(&["--bar", "--", "x"])) {
            Err(FlagParseError::ValueMissing(f)) => assert_eq!(f, "--bar"),
            other => panic!("expected missing value error, got {:?}", other),
        }
    }

    #[test]
    fn test_unknown_flags() {
        match Flags::new().parse(args(&["positional"])) {
            Err(FlagParseError::UnknownFlag(f)) => assert_eq!(f, "positional"),
            other => panic!("expected unknown flag error, got {:?}", other),
        }
        match Flags::new().parse(args(&["--baz", "x"])) {
            Err(FlagParseError::UnknownFlag(f)) => assert_eq!(f, "--baz"),
            other => panic!("expected unknown flag error, got {:?}", other),
        }
    }

    #[test]
    fn test_program_name_missing() {
        match Flags::new().parse(Vec::new()) {
            Err(FlagParseError::ProgramNameMissing) => {}
            other => panic!("expected missing program name error, got {:?}", other),
        }
    }

    #[test]
    fn test_repeated_flags() {
        // Test case 1) --bar something something_else should work as a repeated flag.
        let mut bar = None;
        let mut parser = Flags::new();
        parser.define_repeated_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--bar", "aa", "bb"])).unwrap();
        assert!(matches!(result, ParseOutcome::Parsed(_)));
        assert_eq!(bar, Some(vec!["aa".to_owned(), "bb".to_owned()]));
        // Test case 2) --bar something --bar something_else should also work as a repeated flag.
        bar = None;
        let mut parser = Flags::new();
        parser.define_repeated_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--bar", "aa", "--bar", "bb"])).unwrap();
        assert!(matches!(result, ParseOutcome::Parsed(_)));
        assert_eq!(bar, Some(vec!["aa".to_owned(), "bb".to_owned()]));
    }

    #[test]
    fn test_extra_args() {
        let parser = Flags::new();
        let result = parser.parse(args(&["--", "bb"])).unwrap();
        if let ParseOutcome::Parsed(got) = result {
            assert_eq!(got, vec!["bb".to_owned()])
        } else {
            panic!("expected correct parsing, got {:?}", result)
        }
    }
}
276