1# Copyright 2021 The Abseil Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Defines *private* classes used for flag validators.
16
17Do NOT import this module. DO NOT use anything from this module. They are
18private APIs.
19"""
20
21from absl.flags import _exceptions
22
23
class Validator(object):
  """Base class for flags validators.

  Users should NOT subclass these classes directly; use the flags
  register_*validator functions (e.g. register_validator(),
  register_multi_flags_validator()) instead.

  Concrete subclasses implement get_flags_names(), print_flags_with_values()
  and _get_input_to_checker_function(); verify() is shared by all of them.
  """

  # Class-wide counter used to assign each validator a unique insertion_index.
  validators_count = 0

  def __init__(self, checker, message):
    """Constructor to create all validators.

    Args:
      checker: function to verify the constraint.
          Its input varies by subclass; see SingleFlagValidator and
          MultiFlagsValidator for a detailed description.
      message: str, error message to be shown to the user.
    """
    self.checker = checker
    self.message = message
    # Incremented on the base class, so the counter is shared by every
    # subclass and indices are unique across all validators.
    Validator.validators_count += 1
    # Records registration order so that validators can be checked in the
    # order they were registered.
    self.insertion_index = Validator.validators_count

  def verify(self, flag_values):
    """Verifies that constraint is satisfied.

    flags library calls this method to verify Validator's constraint.

    Args:
      flag_values: flags.FlagValues, the FlagValues instance to get flags from.

    Raises:
      _exceptions.ValidationError: if the checker returns a falsy value; the
          error carries self.message. Any exception raised by the checker
          itself (e.g. its own ValidationError) propagates unchanged.
    """
    param = self._get_input_to_checker_function(flag_values)
    if not self.checker(param):
      raise _exceptions.ValidationError(self.message)

  def get_flags_names(self):
    """Returns the names of the flags checked by this validator.

    Returns:
      [string], names of the flags.

    Raises:
      NotImplementedError: always, in the base class.
    """
    raise NotImplementedError('This method should be overloaded')

  def print_flags_with_values(self, flag_values):
    """Returns a human-readable description of the checked flags and values.

    Args:
      flag_values: flags.FlagValues, the FlagValues instance to get flags from.

    Returns:
      str, describing the checked flags and their current values.

    Raises:
      NotImplementedError: always, in the base class.
    """
    raise NotImplementedError('This method should be overloaded')

  def _get_input_to_checker_function(self, flag_values):
    """Given flag values, returns the input to be given to checker.

    Args:
      flag_values: flags.FlagValues, containing all flags.

    Returns:
      The input to be given to checker. The return type depends on the specific
      validator.

    Raises:
      NotImplementedError: always, in the base class.
    """
    raise NotImplementedError('This method should be overloaded')
84
85
class SingleFlagValidator(Validator):
  """Validator created by register_validator().

  Checks one flag: the checker receives that flag's current value and returns
  True when the value is acceptable. For an unacceptable value the checker
  either returns False or raises an exception.
  """

  def __init__(self, flag_name, checker, message):
    """Initializes the validator.

    Args:
      flag_name: string, name of the flag being validated.
      checker: callable receiving the flag's value (string, boolean, etc.)
          and returning bool; True means the constraint holds. When the
          constraint does not hold it should either return False or
          raise flags.ValidationError(desired_error_message).
      message: str, error message shown to the user when the checker's
          condition is not satisfied.
    """
    super(SingleFlagValidator, self).__init__(checker, message)
    self.flag_name = flag_name

  def get_flags_names(self):
    """Returns a new one-element list holding the validated flag's name."""
    names = [self.flag_name]
    return names

  def print_flags_with_values(self, flag_values):
    """Returns 'flag --<name>=<value>' for the validated flag."""
    current = flag_values[self.flag_name].value
    return 'flag --%s=%s' % (self.flag_name, current)

  def _get_input_to_checker_function(self, flag_values):
    """Returns the checker's input: the current value of the validated flag.

    Args:
      flag_values: flags.FlagValues, the FlagValues instance to get flags from.

    Returns:
      object, the `.value` of flag_values[self.flag_name].
    """
    flag = flag_values[self.flag_name]
    return flag.value
125
126
class MultiFlagsValidator(Validator):
  """Validator behind register_multi_flags_validator method.

  Validates that flag values pass their common checker function. The checker
  function takes flag values and returns True (if values look fine) or,
  if values are not valid, either returns False or raises an Exception.
  """

  def __init__(self, flag_names, checker, message):
    """Constructor.

    Args:
      flag_names: [str], containing names of the flags used by checker.
      checker: function to verify the validator.
          input  - dict, with keys() being flag_names, and value for each
              key being the value of the corresponding flag (string, boolean,
              etc).
          output - bool, True if validator constraint is satisfied.
              If constraint is not satisfied, it should either return False or
              raise flags.ValidationError(desired_error_message).
      message: str, error message to be shown to the user if validator's
          condition is not satisfied.
    """
    super(MultiFlagsValidator, self).__init__(checker, message)
    # Stored as given (not copied); get_flags_names() returns this object.
    self.flag_names = flag_names

  def _get_input_to_checker_function(self, flag_values):
    """Given flag values, returns the input to be given to checker.

    Args:
      flag_values: flags.FlagValues, the FlagValues instance to get flags from.

    Returns:
      dict, with keys() being self.flag_names, and value for each key
      being the value of the corresponding flag (string, boolean, etc).
    """
    return {key: flag_values[key].value for key in self.flag_names}

  def print_flags_with_values(self, flag_values):
    """Returns a description of the checked flags and their current values.

    Args:
      flag_values: flags.FlagValues, the FlagValues instance to get flags from.

    Returns:
      str, of the form 'flags name1=value1, name2=value2', in the order of
      self.flag_names.
    """
    return 'flags ' + ', '.join(
        '%s=%s' % (key, flag_values[key].value) for key in self.flag_names)

  def get_flags_names(self):
    """Returns the flag names passed to the constructor (same object)."""
    return self.flag_names
173