1# Copyright 2017 The Abseil Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Decorator and context manager for saving and restoring flag values.
16
17There are many ways to save and restore.  Always use the most convenient method
18for a given use case.
19
20Here are examples of each method.  They all call ``do_stuff()`` while
21``FLAGS.someflag`` is temporarily set to ``'foo'``::
22
23    from absl.testing import flagsaver
24
25    # Use a decorator which can optionally override flags via arguments.
26    @flagsaver.flagsaver(someflag='foo')
27    def some_func():
28      do_stuff()
29
30    # Use a decorator which can optionally override flags with flagholders.
31    @flagsaver.flagsaver((module.FOO_FLAG, 'foo'), (other_mod.BAR_FLAG, 23))
32    def some_func():
33      do_stuff()
34
35    # Use a decorator which does not override flags itself.
36    @flagsaver.flagsaver
37    def some_func():
38      FLAGS.someflag = 'foo'
39      do_stuff()
40
41    # Use a context manager which can optionally override flags via arguments.
42    with flagsaver.flagsaver(someflag='foo'):
43      do_stuff()
44
45    # Save and restore the flag values yourself.
46    saved_flag_values = flagsaver.save_flag_values()
47    try:
48      FLAGS.someflag = 'foo'
49      do_stuff()
50    finally:
51      flagsaver.restore_flag_values(saved_flag_values)
52
53    # Use the parsing version to emulate users providing the flags.
54    # Note that all flags must be provided as strings (unparsed).
55    @flagsaver.as_parsed(some_int_flag='123')
56    def some_func():
57      # Because the flag was parsed it is considered "present".
      assert FLAGS['some_int_flag'].present
59      do_stuff()
60
61    # flagsaver.as_parsed() can also be used as a context manager just like
62    # flagsaver.flagsaver()
63    with flagsaver.as_parsed(some_int_flag='123'):
64      do_stuff()
65
66    # The flagsaver.as_parsed() interface also supports FlagHolder objects.
67    @flagsaver.as_parsed((module.FOO_FLAG, 'foo'), (other_mod.BAR_FLAG, '23'))
68    def some_func():
69      do_stuff()
70
71    # Using as_parsed with a multi_X flag requires a sequence of strings.
72    @flagsaver.as_parsed(some_multi_int_flag=['123', '456'])
73    def some_func():
      assert FLAGS['some_multi_int_flag'].present
75      do_stuff()
76
77    # If a flag name includes non-identifier characters it can be specified like
78    # so:
79    @flagsaver.as_parsed(**{'i-like-dashes': 'true'})
80    def some_func():
81      do_stuff()
82
83We save and restore a shallow copy of each Flag object's ``__dict__`` attribute.
84This preserves all attributes of the flag, such as whether or not it was
85overridden from its default value.
86
WARNING: Currently a flag that is saved and then deleted cannot be restored:
restoring flag values will not re-create it.  However if you *add* a flag after
saving flag values, and then restore flag values, the added flag will be deleted
with no errors.
90"""
91
92import collections
93import functools
94import inspect
95from typing import overload, Any, Callable, Mapping, Tuple, TypeVar, Type, Sequence, Union
96
97from absl import flags
98
# The FlagValues instance used as the default ``flag_values`` of
# save_flag_values()/restore_flag_values(), and the registry that the
# _FlagOverrider/_ParsingFlagOverrider classes read from and modify.
FLAGS = flags.FLAGS


# The type of pre/post wrapped functions. Bound to Callable so the overloads can
# declare that decorating a callable returns a callable of the same type.
_CallableT = TypeVar('_CallableT', bound=Callable)
104
105
@overload
def flagsaver(*args: Tuple[flags.FlagHolder, Any],
              **kwargs: Any) -> '_FlagOverrider':
  ...


@overload
def flagsaver(func: _CallableT) -> _CallableT:
  ...


def flagsaver(*args, **kwargs):
  """Saves flag values, optionally overrides some, and restores them afterwards.

  Usable as a bare decorator (``@flagsaver``), a parameterized decorator
  (``@flagsaver(...)``) or a context manager (``with flagsaver(...):``).
  Overrides may be given as keyword arguments (flag name -> new value) and/or
  as positional ``(FlagHolder, new value)`` pairs. See the module docstring for
  worked examples.
  """
  overrider_cls = _FlagOverrider
  return _construct_overrider(overrider_cls, *args, **kwargs)
120
121
@overload
def as_parsed(*args: Tuple[flags.FlagHolder, Union[str, Sequence[str]]],
              **kwargs: Union[str, Sequence[str]]) -> '_ParsingFlagOverrider':
  ...


@overload
def as_parsed(func: _CallableT) -> _CallableT:
  ...


def as_parsed(*args, **kwargs):
  """Like flagsaver.flagsaver(), but applies overrides by parsing strings.

  The result works as a decorator or a context manager, exactly like
  flagsaver.flagsaver(). The difference is how new values are applied: rather
  than assigning already-typed values directly, each value is parsed as if it
  had been given on the command line. Among other things, this means
  `FLAGS['flag_name'].present == True` afterwards.

  About unparsed values: most flag types take a single string, but multi_x
  flags (multi_string, multi_integer, multi_enum) take a Sequence of strings.

  Args:
    *args: (FlagHolder, unparsed value) tuples.
    **kwargs: Flag names mapped to their unparsed values.

  Returns:
    A _ParsingFlagOverrider usable as a context manager or decorator. It saves
    the current flag state, parses the new values, and on cleanup restores the
    saved state. When applied directly to a function (``@as_parsed``), the
    wrapped function is returned instead.
  """
  parsing_cls = _ParsingFlagOverrider
  return _construct_overrider(parsing_cls, *args, **kwargs)
157
158
# NOTE: the order of these overload declarations matters. The type checker will
# pick the first match which could be incorrect.
@overload
def _construct_overrider(
    flag_overrider_cls: Type['_ParsingFlagOverrider'],
    *args: Tuple[flags.FlagHolder, Union[str, Sequence[str]]],
    **kwargs: Union[str, Sequence[str]]) -> '_ParsingFlagOverrider':
  ...


@overload
def _construct_overrider(flag_overrider_cls: Type['_FlagOverrider'],
                         *args: Tuple[flags.FlagHolder, Any],
                         **kwargs: Any) -> '_FlagOverrider':
  ...


@overload
def _construct_overrider(flag_overrider_cls: Type['_FlagOverrider'],
                         func: _CallableT) -> _CallableT:
  ...


def _construct_overrider(flag_overrider_cls, *args, **kwargs):
  """Handles the args/kwargs, returning an overrider or a wrapped function.

  If flag_overrider_cls is _FlagOverrider then values should be native python
  types matching the flags' types. Otherwise if flag_overrider_cls is
  _ParsingFlagOverrider the values should be strings or sequences of strings.

  Args:
    flag_overrider_cls: The class that will do the overriding.
    *args: Either a single callable (bare decorator usage such as
      ``@flagsaver``) or any number of (FlagHolder, new value) tuples.
    **kwargs: Keyword args mapping flag name to new flag value.

  Returns:
    An instance of flag_overrider_cls to be used as a decorator or context
    manager; or, when *args is a single callable, that callable wrapped so that
    flags are saved and restored around each call.

  Raises:
    ValueError: If a callable is combined with keyword overrides, if a
      positional argument is not a (FlagHolder, value) pair, or if the same
      flag is specified more than once.
    TypeError: If the single positional callable is a class.
  """
  if not args:
    return flag_overrider_cls(**kwargs)
  # args can be [func] if used as `@flagsaver` instead of `@flagsaver(...)`
  if len(args) == 1 and callable(args[0]):
    if kwargs:
      raise ValueError(
          "It's invalid to specify both positional and keyword parameters.")
    func = args[0]
    if inspect.isclass(func):
      # Name the decorator the caller actually used; previously this always
      # said "flagsaver" even for @flagsaver.as_parsed.
      public_name = ('as_parsed'
                     if issubclass(flag_overrider_cls, _ParsingFlagOverrider)
                     else 'flagsaver')
      raise TypeError(
          '@flagsaver.%s cannot be applied to a class.' % public_name)
    return _wrap(flag_overrider_cls, func, {})
  # args can be a list of (FlagHolder, value) pairs.
  # In which case they augment any specified kwargs.
  for arg in args:
    if not isinstance(arg, tuple) or len(arg) != 2:
      raise ValueError('Expected (FlagHolder, value) pair, found %r' % (arg,))
    holder, value = arg
    if not isinstance(holder, flags.FlagHolder):
      raise ValueError('Expected (FlagHolder, value) pair, found %r' % (arg,))
    # kwargs accumulates the pairs seen so far, so this also catches the same
    # holder being passed twice positionally.
    if holder.name in kwargs:
      raise ValueError('Cannot set --%s multiple times' % holder.name)
    kwargs[holder.name] = value
  return flag_overrider_cls(**kwargs)
220
221
def save_flag_values(
    flag_values: flags.FlagValues = FLAGS) -> Mapping[str, Mapping[str, Any]]:
  """Takes a snapshot of every flag's state.

  Args:
    flag_values: FlagValues, the FlagValues instance whose flags are
      snapshotted. This should almost never need to be overridden.

  Returns:
    A new dict keyed by flag name. Each value is a copy of the corresponding
    flag's ``__dict__``, e.g. ``{'key': value_dict, ...}``.
  """
  snapshot = {}
  for flag_name in flag_values:
    snapshot[flag_name] = _copy_flag_dict(flag_values[flag_name])
  return snapshot
235
236
def restore_flag_values(saved_flag_values: Mapping[str, Mapping[str, Any]],
                        flag_values: flags.FlagValues = FLAGS):
  """Restores flag values based on the dictionary of flag values.

  Flags present in ``flag_values`` but missing from ``saved_flag_values`` are
  deleted. Flags named in ``saved_flag_values`` that no longer exist in
  ``flag_values`` are not re-created.

  Each flag's ``__dict__`` is replaced with a fresh copy of its saved state
  (including a fresh validators list). As a result, live flags never alias
  ``saved_flag_values``, and the same snapshot can be restored more than once.

  Args:
    saved_flag_values: {'flag_name': value_dict, ...}
    flag_values: FlagValues, the FlagValues instance from which the flag will be
      restored. This should almost never need to be overridden.
  """
  # Materialize the names first because flags may be deleted while looping.
  for name in list(flag_values):
    saved = saved_flag_values.get(name)
    if saved is None:
      # If __dict__ was not saved delete "new" flag.
      delattr(flag_values, name)
      continue
    flag = flag_values[name]
    if flag.value != saved['_value']:
      flag.value = saved['_value']  # Ensure C++ value is set.
    # Install a copy rather than the saved mapping itself. Otherwise later
    # writes to the flag would mutate the snapshot, and restoring it again would
    # silently keep the mutated state.
    restored = dict(saved)
    if 'validators' in restored:
      restored['validators'] = list(restored['validators'])
    flag.__dict__ = restored
256
257
@overload
def _wrap(flag_overrider_cls: Type['_FlagOverrider'], func: _CallableT,
          overrides: Mapping[str, Any]) -> _CallableT:
  ...


@overload
def _wrap(flag_overrider_cls: Type['_ParsingFlagOverrider'], func: _CallableT,
          overrides: Mapping[str, Union[str, Sequence[str]]]) -> _CallableT:
  ...


def _wrap(flag_overrider_cls, func, overrides):
  """Returns ``func`` wrapped so each call runs under a fresh flag overrider.

  Args:
    flag_overrider_cls: The context-manager class instantiated with
      ``overrides`` around every call of the wrapper.
    func: The callable to wrap. It is invoked after flags are saved and before
      they are restored.
    overrides: Flag names mapped to the values applied after saving the
      original flag state. Whether the values are native python values or
      unparsed strings depends on flag_overrider_cls (_FlagOverrider vs.
      _ParsingFlagOverrider).

  Returns:
    A wrapper carrying func's metadata (via functools.wraps) that forwards all
    arguments to func and returns its result.
  """

  def _flagsaver_wrapper(*args, **kwargs):
    """Wrapper function that saves and restores flags."""
    overrider = flag_overrider_cls(**overrides)
    with overrider:
      return func(*args, **kwargs)

  return functools.wraps(func)(_flagsaver_wrapper)
291
292
class _FlagOverrider(object):
  """Temporarily overrides flags; usable as a context manager or decorator.

  Entering the context snapshots every flag in FLAGS and then applies the
  overrides. Exiting restores the snapshot, so all flags are back to their
  original values once the ``with`` body or decorated call completes.
  """

  def __init__(self, **overrides: Any):
    # Flag name -> new value, applied via FLAGS._set_attributes on entry.
    self._overrides = overrides
    # Filled in by __enter__ with the result of save_flag_values().
    self._saved_flag_values = None

  def __call__(self, func: _CallableT) -> _CallableT:
    if not inspect.isclass(func):
      return _wrap(type(self), func, self._overrides)
    raise TypeError('flagsaver cannot be applied to a class.')

  def __enter__(self):
    saved = save_flag_values(FLAGS)
    self._saved_flag_values = saved
    try:
      FLAGS._set_attributes(**self._overrides)
    except BaseException:
      # Applying overrides may fail (e.g. a flag validator rejects a value).
      # __exit__ is not called when __enter__ raises, so roll back here.
      restore_flag_values(saved, FLAGS)
      raise

  def __exit__(self, exc_type, exc_value, traceback):
    # Implicitly returns None, so exceptions from the body still propagate.
    restore_flag_values(self._saved_flag_values, FLAGS)
320
321
class _ParsingFlagOverrider(_FlagOverrider):
  """Context manager for overriding flags.

  Simulates command line parsing.

  This is similar to _FlagOverrider except that all **overrides should be
  strings or sequences of strings, and when the context is entered this class
  calls ``.parse(value)`` on each overridden flag.

  This results in the flags having .present set properly.
  """

  def __init__(self, **overrides: Union[str, Sequence[str]]):
    for flag_name, new_value in overrides.items():
      if isinstance(new_value, str):
        continue
      if (isinstance(new_value, collections.abc.Sequence) and
          all(isinstance(single_value, str) for single_value in new_value)):
        continue
      raise TypeError(
          f'flagsaver.as_parsed() cannot parse {flag_name}. Expected a single '
          f'string or sequence of strings but {type(new_value)} was provided.')
    super().__init__(**overrides)

  def __enter__(self):
    self._saved_flag_values = save_flag_values(FLAGS)
    try:
      for flag_name, unparsed_value in self._overrides.items():
        # Check membership explicitly instead of translating any KeyError.
        # UnrecognizedFlagError takes the flag *name*. A KeyError raised by a
        # parser or validator must not be misreported as an unknown flag.
        if flag_name not in FLAGS:
          raise flags.UnrecognizedFlagError(flag_name)
        # LINT.IfChange(flag_override_parsing)
        FLAGS[flag_name].parse(unparsed_value)
        FLAGS[flag_name].using_default_value = False
        # LINT.ThenChange()

      # Perform the validation on all modified flags. This is something that
      # FLAGS._set_attributes() does for you in _FlagOverrider.
      for flag_name in self._overrides:
        FLAGS._assert_validators(FLAGS[flag_name].validators)

    except:
      # It may fail because of unknown flags, flag validators or general
      # parsing issues. __exit__ is not called when __enter__ raises, so undo
      # any partial changes before propagating.
      restore_flag_values(self._saved_flag_values, FLAGS)
      raise
369
370
def _copy_flag_dict(flag: flags.Flag) -> Mapping[str, Any]:
  """Snapshots a single flag object's ``__dict__``.

  The result is a shallow copy of ``__dict__`` with two adjustments: the
  validator list is copied as well, and ``_value`` is read through the flag's
  ``value`` property.

  Args:
    flag: flags.Flag, the flag to snapshot.

  Returns:
    A new dict holding the flag object's attributes.
  """
  return {
      **flag.__dict__,
      # Read via the property so C++-backed flags restore correctly.
      '_value': flag.value,
      # Copy so later validator registrations don't leak into the snapshot.
      'validators': list(flag.validators),
  }
387