# Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains base classes used to parse and convert arguments.

Do NOT import this module directly. Import the flags package and use the
aliases defined at the package level instead.
"""

import collections
import csv
import enum
import io
import string
from typing import Generic, List, Iterable, Optional, Sequence, Text, Type, TypeVar, Union
from xml.dom import minidom

from absl.flags import _helpers

_T = TypeVar('_T')
_ET = TypeVar('_ET', bound=enum.Enum)
_N = TypeVar('_N', int, float)


def _is_integer_type(instance):
  """Returns True if instance is an integer, and not a bool."""
  # bool is a subclass of int, so it has to be excluded explicitly.
  return (isinstance(instance, int) and
          not isinstance(instance, bool))


class _ArgumentParserCache(type):
  """Metaclass used to cache and share argument parsers among flags."""

  # Maps (cls, *positional_args) to the shared parser instance. The dict lives
  # on the metaclass, so it is shared by every parser class.
  _instances = {}

  def __call__(cls, *args, **kwargs):
    """Returns an instance of the argument parser cls.

    This method overrides behavior of the __new__ methods in
    all subclasses of ArgumentParser (inclusive). If an instance
    for cls with the same set of arguments exists, this instance is
    returned, otherwise a new instance is created.

    If any keyword arguments are defined, or the values in args
    are not hashable, this method always returns a new instance of
    cls.

    Args:
      *args: Positional initializer arguments.
      **kwargs: Initializer keyword arguments.

    Returns:
      An instance of cls, shared or new.
    """
    if kwargs:
      return type.__call__(cls, *args, **kwargs)
    else:
      instances = cls._instances
      key = (cls,) + tuple(args)
      try:
        return instances[key]
      except KeyError:
        # No cache entry for key exists, create a new one.
        return instances.setdefault(key, type.__call__(cls, *args))
      except TypeError:
        # An object in args cannot be hashed, always return
        # a new instance.
        return type.__call__(cls, *args)


class ArgumentParser(Generic[_T], metaclass=_ArgumentParserCache):
  """Base class used to parse and convert arguments.

  The :meth:`parse` method checks to make sure that the string argument is a
  legal value and convert it to a native type.  If the value cannot be
  converted, it should throw a ``ValueError`` exception with a human
  readable explanation of why the value is illegal.

  Subclasses should also define a syntactic_help string which may be
  presented to the user to describe the form of the legal values.

  Argument parser classes must be stateless, since instances are cached
  and shared between flags. Initializer arguments are allowed, but all
  member variables must be derived from initializer arguments only.
  """

  syntactic_help: Text = ''

  def parse(self, argument: Text) -> Optional[_T]:
    """Parses the string argument and returns the native value.

    By default it returns its argument unmodified.

    Args:
      argument: string argument passed in the commandline.

    Raises:
      ValueError: Raised when it fails to parse the argument.
      TypeError: Raised when the argument has the wrong type.

    Returns:
      The parsed value in native type.
    """
    if not isinstance(argument, str):
      raise TypeError('flag value must be a string, found "{}"'.format(
          type(argument)))
    return argument

  def flag_type(self) -> Text:
    """Returns a string representing the type of the flag."""
    return 'string'

  def _custom_xml_dom_elements(
      self, doc: minidom.Document
  ) -> List[minidom.Element]:
    """Returns a list of minidom.Element to add additional flag information.

    The base implementation contributes no elements.

    Args:
      doc: minidom.Document, the DOM document it should create nodes from.
    """
    del doc  # Unused.
    return []


class ArgumentSerializer(Generic[_T]):
  """Base class for generating string representations of a flag value."""

  def serialize(self, value: _T) -> Text:
    """Returns a serialized string of the value."""
    return str(value)


class NumericParser(ArgumentParser[_N]):
  """Parser of numeric values.

  Parsed value may be bounded to a given upper and lower bound.
  """

  # Inclusive bounds; None means unbounded on that side. Subclasses assign
  # these in their initializers.
  lower_bound: Optional[_N]
  upper_bound: Optional[_N]

  def is_outside_bounds(self, val: _N) -> bool:
    """Returns whether the value is outside the bounds or not."""
    return ((self.lower_bound is not None and val < self.lower_bound) or
            (self.upper_bound is not None and val > self.upper_bound))

  def parse(self, argument: Text) -> _N:
    """See base class."""
    val = self.convert(argument)
    if self.is_outside_bounds(val):
      raise ValueError('%s is not %s' % (val, self.syntactic_help))
    return val

  def _custom_xml_dom_elements(
      self, doc: minidom.Document
  ) -> List[minidom.Element]:
    """Returns one element per bound that is set (lower first, then upper).

    Args:
      doc: minidom.Document, the DOM document it should create nodes from.
    """
    elements = []
    if self.lower_bound is not None:
      elements.append(_helpers.create_xml_dom_element(
          doc, 'lower_bound', self.lower_bound))
    if self.upper_bound is not None:
      elements.append(_helpers.create_xml_dom_element(
          doc, 'upper_bound', self.upper_bound))
    return elements

  def convert(self, argument: Text) -> _N:
    """Returns the correct numeric value of argument.

    Subclass must implement this method, and raise TypeError if argument is
    neither a string nor a value of the right numeric type.

    Args:
      argument: string argument passed in the commandline, or the numeric type.

    Raises:
      TypeError: Raised when argument is not a string or the right numeric type.
      ValueError: Raised when failed to convert argument to the numeric value.
    """
    raise NotImplementedError


class FloatParser(NumericParser[float]):
  """Parser of floating point values.

  Parsed value may be bounded to a given upper and lower bound.
  """
  number_article = 'a'
  number_name = 'number'
  syntactic_help = ' '.join((number_article, number_name))

  def __init__(
      self,
      lower_bound: Optional[float] = None,
      upper_bound: Optional[float] = None,
  ) -> None:
    """Initializes FloatParser and derives its syntactic_help.

    Args:
      lower_bound: float, the inclusive lower bound, or None for no bound.
      upper_bound: float, the inclusive upper bound, or None for no bound.
    """
    super().__init__()
    self.lower_bound = lower_bound
    self.upper_bound = upper_bound
    # Start from the class-level default and specialize it for the bounds.
    sh = self.syntactic_help
    if lower_bound is not None and upper_bound is not None:
      sh = ('%s in the range [%s, %s]' % (sh, lower_bound, upper_bound))
    elif lower_bound == 0:
      sh = 'a non-negative %s' % self.number_name
    elif upper_bound == 0:
      sh = 'a non-positive %s' % self.number_name
    elif upper_bound is not None:
      sh = '%s <= %s' % (self.number_name, upper_bound)
    elif lower_bound is not None:
      sh = '%s >= %s' % (self.number_name, lower_bound)
    self.syntactic_help = sh

  def convert(self, argument: Union[int, float, str]) -> float:
    """Returns the float value of argument.

    Args:
      argument: a string, a non-bool int, or a float.

    Raises:
      TypeError: Raised when argument is any other type (including bool).
      ValueError: Raised by float() when a string cannot be converted.
    """
    if (_is_integer_type(argument) or isinstance(argument, float) or
        isinstance(argument, str)):
      return float(argument)
    else:
      raise TypeError(
          'Expect argument to be a string, int, or float, found {}'.format(
              type(argument)))

  def flag_type(self) -> Text:
    """See base class."""
    return 'float'


class IntegerParser(NumericParser[int]):
  """Parser of an integer value.

  Parsed value may be bounded to a given upper and lower bound.
  """
  number_article = 'an'
  number_name = 'integer'
  syntactic_help = ' '.join((number_article, number_name))

  def __init__(
      self, lower_bound: Optional[int] = None, upper_bound: Optional[int] = None
  ) -> None:
    """Initializes IntegerParser and derives its syntactic_help.

    Args:
      lower_bound: int, the inclusive lower bound, or None for no bound.
      upper_bound: int, the inclusive upper bound, or None for no bound.
    """
    super().__init__()
    self.lower_bound = lower_bound
    self.upper_bound = upper_bound
    # Start from the class-level default and specialize it for the bounds.
    sh = self.syntactic_help
    if lower_bound is not None and upper_bound is not None:
      sh = ('%s in the range [%s, %s]' % (sh, lower_bound, upper_bound))
    elif lower_bound == 1:
      sh = 'a positive %s' % self.number_name
    elif upper_bound == -1:
      sh = 'a negative %s' % self.number_name
    elif lower_bound == 0:
      sh = 'a non-negative %s' % self.number_name
    elif upper_bound == 0:
      sh = 'a non-positive %s' % self.number_name
    elif upper_bound is not None:
      sh = '%s <= %s' % (self.number_name, upper_bound)
    elif lower_bound is not None:
      sh = '%s >= %s' % (self.number_name, lower_bound)
    self.syntactic_help = sh

  def convert(self, argument: Union[int, Text]) -> int:
    """Returns the int value of argument.

    Strings are read as decimal unless they are longer than two characters
    and start with ``0o`` (octal) or ``0x`` (hexadecimal).

    Args:
      argument: a string or a non-bool int.

    Raises:
      TypeError: Raised when argument is neither a string nor a non-bool int.
      ValueError: Raised by int() when a string cannot be converted.
    """
    if _is_integer_type(argument):
      return argument
    elif isinstance(argument, str):
      base = 10
      if len(argument) > 2 and argument[0] == '0':
        if argument[1] == 'o':
          base = 8
        elif argument[1] == 'x':
          base = 16
      return int(argument, base)
    else:
      raise TypeError('Expect argument to be a string or int, found {}'.format(
          type(argument)))

  def flag_type(self) -> Text:
    """See base class."""
    return 'int'


class BooleanParser(ArgumentParser[bool]):
  """Parser of boolean values."""

  def parse(self, argument: Union[Text, int]) -> bool:
    """See base class.

    Accepts the strings 'true'/'t'/'1' and 'false'/'f'/'0' in any letter case,
    bool values, and the integers 0 and 1.

    Raises:
      ValueError: Raised for any other string or integer.
      TypeError: Raised when argument is neither a string nor an int.
    """
    if isinstance(argument, str):
      lowered = argument.lower()
      if lowered in ('true', 't', '1'):
        return True
      elif lowered in ('false', 'f', '0'):
        return False
      else:
        raise ValueError('Non-boolean argument to boolean flag', argument)
    elif isinstance(argument, int):
      # Only allow bool or integer 0, 1.
      # Note that float 1.0 == True, 0.0 == False.
      bool_value = bool(argument)
      if argument == bool_value:
        return bool_value
      else:
        raise ValueError('Non-boolean argument to boolean flag', argument)

    raise TypeError('Non-boolean argument to boolean flag', argument)

  def flag_type(self) -> Text:
    """See base class."""
    return 'bool'


class EnumParser(ArgumentParser[Text]):
  """Parser of a string enum value (a string value from a given set)."""

  def __init__(
      self, enum_values: Iterable[Text], case_sensitive: bool = True
  ) -> None:
    """Initializes EnumParser.

    Args:
      enum_values: [str], a non-empty list of string values in the enum.
      case_sensitive: bool, whether or not the enum is to be case-sensitive.

    Raises:
      ValueError: When enum_values is empty, or is a single str rather than
        an iterable of str.
    """
    if not enum_values:
      raise ValueError(
          'enum_values cannot be empty, found "{}"'.format(enum_values))
    if isinstance(enum_values, str):
      raise ValueError(
          'enum_values cannot be a str, found "{}"'.format(enum_values)
      )
    # Iterators and generators are always truthy, so the check above cannot
    # detect an empty one; check again once the values are materialized.
    values = list(enum_values)
    if not values:
      raise ValueError(
          'enum_values cannot be empty, found "{}"'.format(enum_values))
    super().__init__()
    self.enum_values = values
    self.case_sensitive = case_sensitive

  def parse(self, argument: Text) -> Text:
    """Determines validity of argument and returns the correct element of enum.

    Args:
      argument: str, the supplied flag value.

    Returns:
      The first matching element from enum_values.

    Raises:
      ValueError: Raised when argument didn't match anything in enum.
    """
    if self.case_sensitive:
      if argument in self.enum_values:
        return argument
    else:
      # Return the stored spelling of the first value that matches when case
      # is ignored.
      wanted = argument.upper()
      for value in self.enum_values:
        if value.upper() == wanted:
          return value
    raise ValueError('value should be one of <%s>' %
                     '|'.join(self.enum_values))

  def flag_type(self) -> Text:
    """See base class."""
    return 'string enum'


class EnumClassParser(ArgumentParser[_ET]):
  """Parser of an Enum class member."""

  def __init__(
      self, enum_class: Type[_ET], case_sensitive: bool = True
  ) -> None:
    """Initializes EnumClassParser.

    Args:
      enum_class: class, the Enum class with all possible flag values.
      case_sensitive: bool, whether or not the enum is to be case-sensitive. If
        False, all member names must be unique when case is ignored.

    Raises:
      TypeError: When enum_class is not a subclass of Enum.
      ValueError: When enum_class is empty, or when case_sensitive is False
        and two member names differ only by case.
    """
    if not issubclass(enum_class, enum.Enum):
      raise TypeError('{} is not a subclass of Enum.'.format(enum_class))
    if not enum_class.__members__:
      raise ValueError('enum_class cannot be empty, but "{}" is empty.'
                       .format(enum_class))
    if not case_sensitive:
      members = collections.Counter(
          name.lower() for name in enum_class.__members__)
      duplicate_keys = {
          member for member, count in members.items() if count > 1
      }
      if duplicate_keys:
        raise ValueError(
            'Duplicate enum values for {} using case_sensitive=False'.format(
                duplicate_keys))

    super().__init__()
    self.enum_class = enum_class
    self._case_sensitive = case_sensitive
    if case_sensitive:
      self._member_names = tuple(enum_class.__members__)
    else:
      self._member_names = tuple(
          name.lower() for name in enum_class.__members__)

  @property
  def member_names(self) -> Sequence[Text]:
    """The accepted enum names, in lowercase if not case sensitive."""
    return self._member_names

  def parse(self, argument: Union[_ET, Text]) -> _ET:
    """Determines validity of argument and returns the correct element of enum.

    Args:
      argument: str or Enum class member, the supplied flag value.

    Returns:
      The first matching Enum class member in Enum class.

    Raises:
      ValueError: Raised when argument didn't match anything in enum.
    """
    if isinstance(argument, self.enum_class):
      return argument  # pytype: disable=bad-return-type
    elif not isinstance(argument, str):
      raise ValueError(
          '{} is not an enum member or a name of a member in {}'.format(
              argument, self.enum_class))
    # Name validation (and its error message) is delegated to EnumParser.
    key = EnumParser(
        self._member_names, case_sensitive=self._case_sensitive).parse(argument)
    if self._case_sensitive:
      return self.enum_class[key]
    else:
      # If EnumParser.parse() return a value, we're guaranteed to find it
      # as a member of the class
      return next(value for name, value in self.enum_class.__members__.items()
                  if name.lower() == key.lower())

  def flag_type(self) -> Text:
    """See base class."""
    return 'enum class'


class ListSerializer(Generic[_T], ArgumentSerializer[List[_T]]):
  """Serializer that joins the str() of each list element with a separator."""

  def __init__(self, list_sep: Text) -> None:
    """Initializes ListSerializer.

    Args:
      list_sep: String placed between elements when serializing.
    """
    self.list_sep = list_sep

  def serialize(self, value: List[_T]) -> Text:
    """See base class."""
    return self.list_sep.join([str(x) for x in value])


class EnumClassListSerializer(ListSerializer[_ET]):
  """A serializer for :class:`MultiEnumClass` flags.

  This serializer simply joins the output of `EnumClassSerializer` using a
  provided separator.
  """

  def __init__(self, list_sep: Text, **kwargs) -> None:
    """Initializes EnumClassListSerializer.

    Args:
      list_sep: String to be used as a separator when serializing
      **kwargs: Keyword arguments to the `EnumClassSerializer` used to serialize
        individual values.
    """
    super().__init__(list_sep)
    self._element_serializer = EnumClassSerializer(**kwargs)

  def serialize(self, value: Union[_ET, List[_ET]]) -> Text:
    """See base class.

    A value that is not a list is serialized as a single element.
    """
    if isinstance(value, list):
      return self.list_sep.join(
          self._element_serializer.serialize(x) for x in value)
    else:
      return self._element_serializer.serialize(value)


class CsvListSerializer(ListSerializer[Text]):
  """Serializer that writes a list as one CSV row delimited by list_sep."""

  def serialize(self, value: List[Text]) -> Text:
    """Serializes a list as a CSV string."""
    output = io.StringIO()
    writer = csv.writer(output, delimiter=self.list_sep)
    writer.writerow([str(x) for x in value])
    # strip() drops the line terminator that writerow() appends, along with
    # any other leading or trailing whitespace in the row.
    return output.getvalue().strip()


class EnumClassSerializer(ArgumentSerializer[_ET]):
  """Class for generating string representations of an enum class flag value."""

  def __init__(self, lowercase: bool) -> None:
    """Initializes EnumClassSerializer.

    Args:
      lowercase: If True, enum member names are lowercased during serialization.
    """
    self._lowercase = lowercase

  def serialize(self, value: _ET) -> Text:
    """Returns a serialized string of the Enum class value."""
    as_string = str(value.name)
    return as_string.lower() if self._lowercase else as_string


class BaseListParser(ArgumentParser):
  """Base class for a parser of lists of strings.

  To extend, inherit from this class; from the subclass ``__init__``, call::

      super().__init__(token, name)

  where token is a character used to tokenize, and name is a description
  of the separator.
  """

  def __init__(
      self, token: Optional[Text] = None, name: Optional[Text] = None
  ) -> None:
    """Initializes BaseListParser.

    Args:
      token: The separator passed to str.split(); None splits on runs of
        whitespace.
      name: Description of the separator, used in syntactic_help and
        flag_type(). Must be non-empty even though it defaults to None.
    """
    # NOTE: assert statements are skipped when Python runs with -O.
    assert name
    super().__init__()
    self._token = token
    self._name = name
    self.syntactic_help = 'a %s separated list' % self._name

  def parse(self, argument: Union[Text, List[Text]]) -> List[Text]:
    """See base class.

    A list is returned unchanged and an empty value yields an empty list;
    any other value is split on the token and each piece is stripped.
    """
    if isinstance(argument, list):
      return argument
    elif not argument:
      return []
    else:
      return [s.strip() for s in argument.split(self._token)]

  def flag_type(self) -> Text:
    """See base class."""
    return '%s separated list of strings' % self._name


class ListParser(BaseListParser):
  """Parser for a comma-separated list of strings."""

  def __init__(self) -> None:
    super().__init__(',', 'comma')

  def parse(self, argument: Union[Text, List[Text]]) -> List[Text]:
    """Parses argument as comma-separated list of strings.

    Args:
      argument: the flag value as a string, or an already-parsed list, which
        is returned unchanged.

    Returns:
      [str], the stripped fields of the value read as a single CSV row.

    Raises:
      ValueError: Raised when the csv module rejects the value.
    """
    if isinstance(argument, list):
      return argument
    elif not argument:
      return []
    else:
      try:
        return [s.strip() for s in list(csv.reader([argument], strict=True))[0]]
      except csv.Error as e:
        # Provide a helpful report for case like
        #   --listflag="$(printf 'hello,\nworld')"
        # IOW, list flag values containing naked newlines.  This error
        # was previously "reported" by allowing csv.Error to
        # propagate.
        raise ValueError('Unable to parse the value %r as a %s: %s'
                         % (argument, self.flag_type(), e)) from e

  def _custom_xml_dom_elements(
      self, doc: minidom.Document
  ) -> List[minidom.Element]:
    """Returns the base class elements plus one for the ',' separator.

    Args:
      doc: minidom.Document, the DOM document it should create nodes from.
    """
    elements = super()._custom_xml_dom_elements(doc)
    elements.append(_helpers.create_xml_dom_element(
        doc, 'list_separator', repr(',')))
    return elements


class WhitespaceSeparatedListParser(BaseListParser):
  """Parser for a whitespace-separated list of strings."""

  def __init__(self, comma_compat: bool = False) -> None:
    """Initializer.

    Args:
      comma_compat: bool, whether to support comma as an additional separator.
        If False then only whitespace is supported.  This is intended only for
        backwards compatibility with flags that used to be comma-separated.
    """
    self._comma_compat = comma_compat
    name = 'whitespace or comma' if self._comma_compat else 'whitespace'
    super().__init__(None, name)

  def parse(self, argument: Union[Text, List[Text]]) -> List[Text]:
    """Parses argument as whitespace-separated list of strings.

    It also parses argument as comma-separated list of strings if requested.

    Args:
      argument: string argument passed in the commandline.

    Returns:
      [str], the parsed flag value.
    """
    if isinstance(argument, list):
      return argument
    elif not argument:
      return []
    else:
      if self._comma_compat:
        # Turn commas into spaces so that a single split() handles both.
        argument = argument.replace(',', ' ')
      return argument.split()

  def _custom_xml_dom_elements(
      self, doc: minidom.Document
  ) -> List[minidom.Element]:
    """Returns the base class elements plus one per accepted separator.

    The separators are every character in string.whitespace, plus ',' when
    comma_compat is set, in sorted order.

    Args:
      doc: minidom.Document, the DOM document it should create nodes from.
    """
    elements = super()._custom_xml_dom_elements(doc)
    separators = list(string.whitespace)
    if self._comma_compat:
      separators.append(',')
    separators.sort()
    for sep_char in separators:
      elements.append(_helpers.create_xml_dom_element(
          doc, 'list_separator', repr(sep_char)))
    return elements