1# Copyright 2017 The Abseil Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Contains base classes used to parse and convert arguments.
16
17Do NOT import this module directly. Import the flags package and use the
18aliases defined at the package level instead.
19"""
20
21import collections
22import csv
23import enum
24import io
25import string
26from typing import Generic, List, Iterable, Optional, Sequence, Text, Type, TypeVar, Union
27from xml.dom import minidom
28
29from absl.flags import _helpers
30
# Type variables for the generic parser and serializer classes in this module.
# _T: the native value type (see ArgumentParser.parse / ArgumentSerializer).
_T = TypeVar('_T')
# _ET: an enum member type; bound to enum.Enum.
_ET = TypeVar('_ET', bound=enum.Enum)
# _N: a numeric type; constrained to exactly int or float.
_N = TypeVar('_N', int, float)
34
35
36def _is_integer_type(instance):
37  """Returns True if instance is an integer, and not a bool."""
38  return (isinstance(instance, int) and
39          not isinstance(instance, bool))
40
41
42class _ArgumentParserCache(type):
43  """Metaclass used to cache and share argument parsers among flags."""
44
45  _instances = {}
46
47  def __call__(cls, *args, **kwargs):
48    """Returns an instance of the argument parser cls.
49
50    This method overrides behavior of the __new__ methods in
51    all subclasses of ArgumentParser (inclusive). If an instance
52    for cls with the same set of arguments exists, this instance is
53    returned, otherwise a new instance is created.
54
55    If any keyword arguments are defined, or the values in args
56    are not hashable, this method always returns a new instance of
57    cls.
58
59    Args:
60      *args: Positional initializer arguments.
61      **kwargs: Initializer keyword arguments.
62
63    Returns:
64      An instance of cls, shared or new.
65    """
66    if kwargs:
67      return type.__call__(cls, *args, **kwargs)
68    else:
69      instances = cls._instances
70      key = (cls,) + tuple(args)
71      try:
72        return instances[key]
73      except KeyError:
74        # No cache entry for key exists, create a new one.
75        return instances.setdefault(key, type.__call__(cls, *args))
76      except TypeError:
77        # An object in args cannot be hashed, always return
78        # a new instance.
79        return type.__call__(cls, *args)
80
81
class ArgumentParser(Generic[_T], metaclass=_ArgumentParserCache):
  """Base class for objects that parse and convert flag arguments.

  :meth:`parse` verifies that a string argument is a legal value and converts
  it to a native type. When the value cannot be converted, the method should
  raise ``ValueError`` with a human readable explanation of why the value is
  illegal.

  Subclasses should also provide a ``syntactic_help`` string that can be shown
  to the user to describe the form of the legal values.

  Parser classes must be stateless: instances are cached and shared between
  flags. Initializer arguments are allowed, but every member variable must be
  derived from the initializer arguments only.
  """

  syntactic_help: Text = ''

  def parse(self, argument: Text) -> Optional[_T]:
    """Converts the string argument to its native value.

    This base implementation hands the argument back unmodified.

    Args:
      argument: string argument passed in the commandline.

    Raises:
      ValueError: Raised when it fails to parse the argument.
      TypeError: Raised when the argument has the wrong type.

    Returns:
      The parsed value in native type.
    """
    if isinstance(argument, str):
      return argument
    found_type = type(argument)
    raise TypeError(
        'flag value must be a string, found "{}"'.format(found_type))

  def flag_type(self) -> Text:
    """Returns a string naming the type of the flag."""
    return 'string'

  def _custom_xml_dom_elements(
      self, doc: minidom.Document
  ) -> List[minidom.Element]:
    """Returns extra minidom.Element nodes describing the flag.

    The base parser has nothing to add and returns an empty list.

    Args:
      doc: minidom.Document, the DOM document it should create nodes from.
    """
    del doc  # Unused.
    no_elements: List[minidom.Element] = []
    return no_elements
134
135
class ArgumentSerializer(Generic[_T]):
  """Base class that turns a flag value into its string representation."""

  def serialize(self, value: _T) -> Text:
    """Returns ``str(value)`` as the serialized form of value."""
    serialized = str(value)
    return serialized
142
143
class NumericParser(ArgumentParser[_N]):
  """Parser of numeric values.

  The parsed value may be restricted by an inclusive lower and/or upper bound;
  a bound of None means that side is unrestricted.
  """

  lower_bound: Optional[_N]
  upper_bound: Optional[_N]

  def is_outside_bounds(self, val: _N) -> bool:
    """Returns whether val lies below lower_bound or above upper_bound."""
    if self.lower_bound is not None and val < self.lower_bound:
      return True
    if self.upper_bound is None:
      return False
    return val > self.upper_bound

  def parse(self, argument: Text) -> _N:
    """See base class.

    Raises:
      ValueError: Raised when the converted value is outside the bounds.
    """
    number = self.convert(argument)
    if not self.is_outside_bounds(number):
      return number
    raise ValueError('%s is not %s' % (number, self.syntactic_help))

  def _custom_xml_dom_elements(
      self, doc: minidom.Document
  ) -> List[minidom.Element]:
    """Returns one element per bound that is set, lower bound first."""
    bounds = (
        ('lower_bound', self.lower_bound),
        ('upper_bound', self.upper_bound),
    )
    return [
        _helpers.create_xml_dom_element(doc, tag, bound)
        for tag, bound in bounds
        if bound is not None
    ]

  def convert(self, argument: Text) -> _N:
    """Returns the correct numeric value of argument.

    Subclass must implement this method, and raise TypeError if argument is
    neither a string nor a value of the right numeric type.

    Args:
      argument: string argument passed in the commandline, or the numeric type.

    Raises:
      TypeError: Raised when argument is not a string or the right numeric type.
      ValueError: Raised when failed to convert argument to the numeric value.
    """
    raise NotImplementedError
191
192
class FloatParser(NumericParser[float]):
  """Parser of floating point values.

  The parsed value may be restricted by a lower and/or an upper bound.
  """
  number_article = 'a'
  number_name = 'number'
  syntactic_help = ' '.join((number_article, number_name))

  def __init__(
      self,
      lower_bound: Optional[float] = None,
      upper_bound: Optional[float] = None,
  ) -> None:
    """Stores the bounds and derives syntactic_help from them.

    Args:
      lower_bound: Smallest accepted value, or None for no lower limit.
      upper_bound: Largest accepted value, or None for no upper limit.
    """
    super().__init__()
    self.lower_bound = lower_bound
    self.upper_bound = upper_bound

    description = self.syntactic_help
    bounded_below = lower_bound is not None
    bounded_above = upper_bound is not None
    if bounded_below and bounded_above:
      description = '%s in the range [%s, %s]' % (
          description, lower_bound, upper_bound)
    elif bounded_below:
      # Only a lower bound is set.
      if lower_bound == 0:
        description = 'a non-negative %s' % self.number_name
      else:
        description = '%s >= %s' % (self.number_name, lower_bound)
    elif bounded_above:
      # Only an upper bound is set.
      if upper_bound == 0:
        description = 'a non-positive %s' % self.number_name
      else:
        description = '%s <= %s' % (self.number_name, upper_bound)
    self.syntactic_help = description

  def convert(self, argument: Union[int, float, str]) -> float:
    """Returns the float value of argument.

    Raises:
      TypeError: Raised when argument is not a string, a non-bool int, or a
        float.
    """
    accepted = (
        _is_integer_type(argument) or isinstance(argument, (float, str)))
    if not accepted:
      raise TypeError(
          'Expect argument to be a string, int, or float, found {}'.format(
              type(argument)))
    return float(argument)

  def flag_type(self) -> Text:
    """See base class."""
    return 'float'
236
237
class IntegerParser(NumericParser[int]):
  """Parser of an integer value.

  The parsed value may be restricted by a lower and/or an upper bound.
  """
  number_article = 'an'
  number_name = 'integer'
  syntactic_help = ' '.join((number_article, number_name))

  def __init__(
      self, lower_bound: Optional[int] = None, upper_bound: Optional[int] = None
  ) -> None:
    """Stores the bounds and derives syntactic_help from them.

    Args:
      lower_bound: Smallest accepted value, or None for no lower limit.
      upper_bound: Largest accepted value, or None for no upper limit.
    """
    super().__init__()
    self.lower_bound = lower_bound
    self.upper_bound = upper_bound

    description = self.syntactic_help
    bounded_below = lower_bound is not None
    bounded_above = upper_bound is not None
    if bounded_below and bounded_above:
      description = '%s in the range [%s, %s]' % (
          description, lower_bound, upper_bound)
    elif bounded_below:
      # Only a lower bound is set.
      if lower_bound == 1:
        description = 'a positive %s' % self.number_name
      elif lower_bound == 0:
        description = 'a non-negative %s' % self.number_name
      else:
        description = '%s >= %s' % (self.number_name, lower_bound)
    elif bounded_above:
      # Only an upper bound is set.
      if upper_bound == -1:
        description = 'a negative %s' % self.number_name
      elif upper_bound == 0:
        description = 'a non-positive %s' % self.number_name
      else:
        description = '%s <= %s' % (self.number_name, upper_bound)
    self.syntactic_help = description

  def convert(self, argument: Union[int, Text]) -> int:
    """Returns the int value of argument.

    A non-bool int is returned as is. A string longer than two characters
    that starts with '0o' is read as octal and one that starts with '0x' as
    hexadecimal; every other string is read as decimal.

    Raises:
      TypeError: Raised when argument is neither a string nor a non-bool int.
    """
    if _is_integer_type(argument):
      return argument
    if not isinstance(argument, str):
      raise TypeError('Expect argument to be a string or int, found {}'.format(
          type(argument)))
    radix = 10
    if len(argument) > 2 and argument[0] == '0':
      # The character after the leading zero selects the base, if recognized.
      radix = {'o': 8, 'x': 16}.get(argument[1], 10)
    return int(argument, radix)

  def flag_type(self) -> Text:
    """See base class."""
    return 'int'
289
290
class BooleanParser(ArgumentParser[bool]):
  """Parser of boolean values."""

  def parse(self, argument: Union[Text, int]) -> bool:
    """See base class.

    Strings are matched case-insensitively against 'true'/'t'/'1' and
    'false'/'f'/'0'. Integers (bool included) are accepted only when equal to
    0 or 1.

    Raises:
      ValueError: Raised for a string or int that is not one of the above.
      TypeError: Raised when argument is neither a string nor an int.
    """
    if isinstance(argument, str):
      lowered = argument.lower()
      if lowered in ('true', 't', '1'):
        return True
      if lowered in ('false', 'f', '0'):
        return False
      raise ValueError('Non-boolean argument to boolean flag', argument)

    if isinstance(argument, int):
      # Only allow bool or integer 0, 1.
      # Note that float 1.0 == True, 0.0 == False.
      as_bool = bool(argument)
      if argument == as_bool:
        return as_bool
      raise ValueError('Non-boolean argument to boolean flag', argument)

    raise TypeError('Non-boolean argument to boolean flag', argument)

  def flag_type(self) -> Text:
    """See base class."""
    return 'bool'
317
318
class EnumParser(ArgumentParser[Text]):
  """Parser of a string enum value (a string value from a given set)."""

  def __init__(
      self, enum_values: Iterable[Text], case_sensitive: bool = True
  ) -> None:
    """Initializes EnumParser.

    Args:
      enum_values: [str], a non-empty iterable of string values in the enum.
        It is copied into a list.
      case_sensitive: bool, whether or not the enum is to be case-sensitive.

    Raises:
      ValueError: When enum_values is empty (including an iterator that
        yields nothing) or is a single str.
    """
    if not enum_values:
      raise ValueError(
          'enum_values cannot be empty, found "{}"'.format(enum_values))
    if isinstance(enum_values, str):
      raise ValueError(
          'enum_values cannot be a str, found "{}"'.format(enum_values)
      )
    # Iterators and generators are always truthy, so the check above cannot
    # detect an empty one. Materialize the values first, then check again.
    values = list(enum_values)
    if not values:
      raise ValueError(
          'enum_values cannot be empty, found "{}"'.format(values))
    super(EnumParser, self).__init__()
    self.enum_values = values
    self.case_sensitive = case_sensitive

  def parse(self, argument: Text) -> Text:
    """Determines validity of argument and returns the correct element of enum.

    Args:
      argument: str, the supplied flag value.

    Returns:
      The argument itself when case-sensitive; otherwise the first element of
      enum_values that equals the argument after upper-casing both.

    Raises:
      ValueError: Raised when argument didn't match anything in enum.
    """
    if self.case_sensitive:
      if argument in self.enum_values:
        return argument
    else:
      # Upper-case the argument once and stop at the first matching value.
      wanted = argument.upper()
      for value in self.enum_values:
        if value.upper() == wanted:
          return value
    raise ValueError('value should be one of <%s>' %
                     '|'.join(self.enum_values))

  def flag_type(self) -> Text:
    """See base class."""
    return 'string enum'
374
375
class EnumClassParser(ArgumentParser[_ET]):
  """Parser of an Enum class member."""

  def __init__(
      self, enum_class: Type[_ET], case_sensitive: bool = True
  ) -> None:
    """Initializes EnumClassParser.

    Args:
      enum_class: class, the Enum class with all possible flag values.
      case_sensitive: bool, whether or not the enum is to be case-sensitive. If
        False, all member names must be unique when case is ignored.

    Raises:
      TypeError: When enum_class is not a subclass of Enum.
      ValueError: When enum_class is empty, or when case_sensitive is False
        and two member names differ only by case.
    """
    if not issubclass(enum_class, enum.Enum):
      raise TypeError('{} is not a subclass of Enum.'.format(enum_class))

    declared_names = tuple(enum_class.__members__)
    if not declared_names:
      raise ValueError('enum_class cannot be empty, but "{}" is empty.'
                       .format(enum_class))

    if case_sensitive:
      accepted_names = declared_names
    else:
      accepted_names = tuple(name.lower() for name in declared_names)
      # Names that collide once lowercased would make lookups ambiguous.
      tally = collections.Counter(accepted_names)
      clashes = {name for name, seen in tally.items() if seen > 1}
      if clashes:
        raise ValueError(
            'Duplicate enum values for {} using case_sensitive=False'.format(
                clashes))

    super().__init__()
    self.enum_class = enum_class
    self._case_sensitive = case_sensitive
    self._member_names = accepted_names

  @property
  def member_names(self) -> Sequence[Text]:
    """The accepted enum names, in lowercase if not case sensitive."""
    return self._member_names

  def parse(self, argument: Union[_ET, Text]) -> _ET:
    """Determines validity of argument and returns the correct element of enum.

    Args:
      argument: str or Enum class member, the supplied flag value.

    Returns:
      The argument itself when it already is a member of the Enum class,
      otherwise the first Enum class member whose name matches the string.

    Raises:
      ValueError: Raised when argument didn't match anything in enum.
    """
    if isinstance(argument, self.enum_class):
      return argument  # pytype: disable=bad-return-type
    if not isinstance(argument, str):
      raise ValueError(
          '{} is not an enum member or a name of a member in {}'.format(
              argument, self.enum_class))

    # Name validation is delegated to a string EnumParser over member names.
    name_parser = EnumParser(
        self._member_names, case_sensitive=self._case_sensitive)
    key = name_parser.parse(argument)
    if self._case_sensitive:
      return self.enum_class[key]

    # If EnumParser.parse() return a value, we're guaranteed to find it
    # as a member of the class
    wanted = key.lower()
    matches = (
        member for name, member in self.enum_class.__members__.items()
        if name.lower() == wanted)
    return next(matches)

  def flag_type(self) -> Text:
    """See base class."""
    return 'enum class'
454
455
class ListSerializer(Generic[_T], ArgumentSerializer[List[_T]]):
  """Serializer that joins the str() of each list element with a separator."""

  def __init__(self, list_sep: Text) -> None:
    """Initializes ListSerializer.

    Args:
      list_sep: String placed between elements when serializing.
    """
    self.list_sep = list_sep

  def serialize(self, value: List[_T]) -> Text:
    """See base class."""
    pieces = list(map(str, value))
    return self.list_sep.join(pieces)
464
465
class EnumClassListSerializer(ListSerializer[_ET]):
  """A serializer for :class:`MultiEnumClass` flags.

  Each element is serialized with an `EnumClassSerializer` and the results
  are joined using the provided separator.
  """

  def __init__(self, list_sep: Text, **kwargs) -> None:
    """Initializes EnumClassListSerializer.

    Args:
      list_sep: String to be used as a separator when serializing
      **kwargs: Keyword arguments to the `EnumClassSerializer` used to serialize
        individual values.
    """
    super().__init__(list_sep)
    self._element_serializer = EnumClassSerializer(**kwargs)

  def serialize(self, value: Union[_ET, List[_ET]]) -> Text:
    """See base class.

    A value that is not a list is serialized as a single element.
    """
    to_text = self._element_serializer.serialize
    if not isinstance(value, list):
      return to_text(value)
    return self.list_sep.join([to_text(member) for member in value])
491
492
class CsvListSerializer(ListSerializer[Text]):
  """Serializer that writes a list as one CSV row delimited by list_sep."""

  def serialize(self, value: List[Text]) -> Text:
    """Serializes a list as a CSV string or unicode."""
    buffer = io.StringIO()
    row_writer = csv.writer(buffer, delimiter=self.list_sep)
    row_writer.writerow([str(item) for item in value])
    # Surrounding whitespace, including the row terminator, is stripped.
    csv_text = buffer.getvalue().strip()

    # We need the returned value to be pure ascii or Unicodes so that
    # when the xml help is generated they are usefully encodable.
    return str(csv_text)
505
506
class EnumClassSerializer(ArgumentSerializer[_ET]):
  """Serializer that renders an enum class flag value as its member name."""

  def __init__(self, lowercase: bool) -> None:
    """Initializes EnumClassSerializer.

    Args:
      lowercase: If True, enum member names are lowercased during serialization.
    """
    self._lowercase = lowercase

  def serialize(self, value: _ET) -> Text:
    """Returns the name of the Enum class value, lowercased if configured."""
    member_name = str(value.name)
    if self._lowercase:
      return member_name.lower()
    return member_name
522
523
class BaseListParser(ArgumentParser):
  """Base class for a parser of lists of strings.

  To extend, inherit from this class; from the subclass ``__init__``, call::

      super().__init__(token, name)

  where token is a character used to tokenize, and name is a description
  of the separator.
  """

  def __init__(
      self, token: Optional[Text] = None, name: Optional[Text] = None
  ) -> None:
    """Records the separator token and builds the help text from name.

    Args:
      token: Separator handed to ``str.split`` when parsing.
      name: Description of the separator; must be truthy.
    """
    assert name
    super().__init__()
    self._token = token
    self._name = name
    self.syntactic_help = 'a %s separated list' % name

  def parse(self, argument: Text) -> List[Text]:
    """See base class.

    A list argument is returned as is and any other falsy argument yields an
    empty list. Otherwise the argument is split on the token and every piece
    is stripped of surrounding whitespace.
    """
    if isinstance(argument, list):
      return argument
    if not argument:
      return []
    pieces = argument.split(self._token)
    return [piece.strip() for piece in pieces]

  def flag_type(self) -> Text:
    """See base class."""
    return '%s separated list of strings' % self._name
556
557
class ListParser(BaseListParser):
  """Parser for a comma-separated list of strings."""

  def __init__(self) -> None:
    super(ListParser, self).__init__(',', 'comma')

  def parse(self, argument: Union[Text, List[Text]]) -> List[Text]:
    """Parses argument as comma-separated list of strings.

    Args:
      argument: The flag value, either a string to split or a list, which is
        returned unchanged.

    Returns:
      The list of values with surrounding whitespace stripped from each, or
      an empty list when argument is falsy.

    Raises:
      ValueError: Raised when the csv module rejects the argument; the
        original csv.Error is attached as the cause.
    """
    if isinstance(argument, list):
      return argument
    elif not argument:
      return []
    try:
      # The whole argument is fed to csv.reader as a single line; strict=True
      # makes malformed input raise csv.Error.
      row = list(csv.reader([argument], strict=True))[0]
    except csv.Error as e:
      # Provide a helpful report for case like
      #   --listflag="$(printf 'hello,\nworld')"
      # IOW, list flag values containing naked newlines.  This error
      # was previously "reported" by allowing csv.Error to
      # propagate.
      raise ValueError('Unable to parse the value %r as a %s: %s'
                       % (argument, self.flag_type(), e)) from e
    return [s.strip() for s in row]

  def _custom_xml_dom_elements(
      self, doc: minidom.Document
  ) -> List[minidom.Element]:
    """Returns the base class elements plus a 'list_separator' for ','."""
    elements = super(ListParser, self)._custom_xml_dom_elements(doc)
    elements.append(_helpers.create_xml_dom_element(
        doc, 'list_separator', repr(',')))
    return elements
589
590
class WhitespaceSeparatedListParser(BaseListParser):
  """Parser for a whitespace-separated list of strings."""

  def __init__(self, comma_compat: bool = False) -> None:
    """Initializer.

    Args:
      comma_compat: bool, whether to support comma as an additional separator.
          If False then only whitespace is supported.  This is intended only for
          backwards compatibility with flags that used to be comma-separated.
    """
    self._comma_compat = comma_compat
    if comma_compat:
      separator_name = 'whitespace or comma'
    else:
      separator_name = 'whitespace'
    super().__init__(None, separator_name)

  def parse(self, argument: Union[Text, List[Text]]) -> List[Text]:
    """Parses argument as whitespace-separated list of strings.

    It also parses argument as comma-separated list of strings if requested.

    Args:
      argument: string argument passed in the commandline.

    Returns:
      [str], the parsed flag value.
    """
    if isinstance(argument, list):
      return argument
    if not argument:
      return []
    # In compatibility mode commas are turned into spaces before splitting.
    text = argument.replace(',', ' ') if self._comma_compat else argument
    return text.split()

  def _custom_xml_dom_elements(
      self, doc: minidom.Document
  ) -> List[minidom.Element]:
    """Returns base elements plus one 'list_separator' per separator char."""
    elements = super()._custom_xml_dom_elements(doc)
    extra = ',' if self._comma_compat else ''
    # Sorting the characters of the string gives the separators in order.
    separators = sorted(string.whitespace + extra)
    elements.extend(
        _helpers.create_xml_dom_element(doc, 'list_separator', repr(sep_char))
        for sep_char in separators)
    return elements
639