# Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Decorator and context manager for saving and restoring flag values.

There are many ways to save and restore.  Always use the most convenient method
for a given use case.

Here are examples of each method.  The ``flagsaver.flagsaver()`` and
``save_flag_values()`` examples all call ``do_stuff()`` while
``FLAGS.someflag`` is temporarily set to ``'foo'``; the ``flagsaver.as_parsed()``
examples override the flags named in each example instead::

  from absl.testing import flagsaver

  # Use a decorator which can optionally override flags via arguments.
  @flagsaver.flagsaver(someflag='foo')
  def some_func():
    do_stuff()

  # Use a decorator which can optionally override flags with flagholders.
  @flagsaver.flagsaver((module.FOO_FLAG, 'foo'), (other_mod.BAR_FLAG, 23))
  def some_func():
    do_stuff()

  # Use a decorator which does not override flags itself.
  @flagsaver.flagsaver
  def some_func():
    FLAGS.someflag = 'foo'
    do_stuff()

  # Use a context manager which can optionally override flags via arguments.
  with flagsaver.flagsaver(someflag='foo'):
    do_stuff()

  # Save and restore the flag values yourself.
  saved_flag_values = flagsaver.save_flag_values()
  try:
    FLAGS.someflag = 'foo'
    do_stuff()
  finally:
    flagsaver.restore_flag_values(saved_flag_values)

  # Use the parsing version to emulate users providing the flags.
  # Note that all flags must be provided as strings (unparsed).
  @flagsaver.as_parsed(some_int_flag='123')
  def some_func():
    # Because the flag was parsed it is considered "present". ``present`` is
    # an attribute of the Flag object (``FLAGS['some_int_flag']``), not of the
    # parsed value returned by ``FLAGS.some_int_flag``.
    assert FLAGS['some_int_flag'].present
    do_stuff()

  # flagsaver.as_parsed() can also be used as a context manager just like
  # flagsaver.flagsaver()
  with flagsaver.as_parsed(some_int_flag='123'):
    do_stuff()

  # The flagsaver.as_parsed() interface also supports FlagHolder objects.
  @flagsaver.as_parsed((module.FOO_FLAG, 'foo'), (other_mod.BAR_FLAG, '23'))
  def some_func():
    do_stuff()

  # Using as_parsed with a multi_X flag requires a sequence of strings.
  @flagsaver.as_parsed(some_multi_int_flag=['123', '456'])
  def some_func():
    assert FLAGS['some_multi_int_flag'].present
    do_stuff()

  # If a flag name includes non-identifier characters it can be specified like
  # so:
  @flagsaver.as_parsed(**{'i-like-dashes': 'true'})
  def some_func():
    do_stuff()

We save and restore a shallow copy of each Flag object's ``__dict__`` attribute.
This preserves all attributes of the flag, such as whether or not it was
overridden from its default value.

WARNING: Currently a flag that is saved and then deleted cannot be restored:
``restore_flag_values()`` only visits flags that still exist, so the deleted
flag silently stays deleted. However if you *add* a flag after saving flag
values, and then restore flag values, the added flag will be deleted with no
errors.
"""

import collections
import functools
import inspect
from typing import overload, Any, Callable, Mapping, Tuple, TypeVar, Type, Sequence, Union

from absl import flags

FLAGS = flags.FLAGS


# The type of pre/post wrapped functions.
_CallableT = TypeVar('_CallableT', bound=Callable)


@overload
def flagsaver(*args: Tuple[flags.FlagHolder, Any],
              **kwargs: Any) -> '_FlagOverrider':
  ...


@overload
def flagsaver(func: _CallableT) -> _CallableT:
  ...


def flagsaver(*args, **kwargs):
  """The main flagsaver interface. See module doc for usage.

  Args:
    *args: Either a single callable to decorate (``@flagsaver.flagsaver``
      written without parentheses), or ``(FlagHolder, value)`` tuples naming
      flags to override with native Python values.
    **kwargs: Flag names mapped to the native Python values to set them to.

  Returns:
    When given a single callable, a wrapped version of it that saves flags
    before each call and restores them afterwards. Otherwise a _FlagOverrider
    that can be used as a decorator or as a context manager.
  """
  return _construct_overrider(_FlagOverrider, *args, **kwargs)


@overload
def as_parsed(*args: Tuple[flags.FlagHolder, Union[str, Sequence[str]]],
              **kwargs: Union[str, Sequence[str]]) -> '_ParsingFlagOverrider':
  ...


@overload
def as_parsed(func: _CallableT) -> _CallableT:
  ...


def as_parsed(*args, **kwargs):
  """Overrides flags by parsing strings, saves flag state similar to flagsaver.

  This function can be used as either a decorator or context manager similar to
  flagsaver.flagsaver(). However, where flagsaver.flagsaver() directly sets the
  flags to new values, this function will parse the provided arguments as if
  they were provided on the command line. Among other things, this will cause
  `FLAGS['flag_name'].present == True`.

  A note on unparsed input: For many flag types, the unparsed version will be
  a single string. However for multi_x (multi_string, multi_integer, multi_enum)
  the unparsed version will be a Sequence of strings.

  Args:
    *args: Tuples of FlagHolders and their unparsed value.
    **kwargs: The keyword args are flag names, and the values are unparsed
      values.

  Returns:
    _ParsingFlagOverrider that serves as a context manager or decorator. Will
    save previous flag state and parse new flags, then on cleanup it will
    restore the previous flag state.
  """
  return _construct_overrider(_ParsingFlagOverrider, *args, **kwargs)


# NOTE: the order of these overload declarations matters. The type checker will
# pick the first match which could be incorrect.
@overload
def _construct_overrider(
    flag_overrider_cls: Type['_ParsingFlagOverrider'],
    *args: Tuple[flags.FlagHolder, Union[str, Sequence[str]]],
    **kwargs: Union[str, Sequence[str]]) -> '_ParsingFlagOverrider':
  ...
@overload
def _construct_overrider(flag_overrider_cls: Type['_FlagOverrider'],
                         *args: Tuple[flags.FlagHolder, Any],
                         **kwargs: Any) -> '_FlagOverrider':
  ...


@overload
def _construct_overrider(flag_overrider_cls: Type['_FlagOverrider'],
                         func: _CallableT) -> _CallableT:
  ...


def _construct_overrider(flag_overrider_cls, *args, **kwargs):
  """Handles the args/kwargs returning an instance of flag_overrider_cls.

  If flag_overrider_cls is _FlagOverrider then values should be native python
  types matching the python types. Otherwise if flag_overrider_cls is
  _ParsingFlagOverrider the values should be strings or sequences of strings.

  Args:
    flag_overrider_cls: The class that will do the overriding.
    *args: Either a single callable (bare-decorator usage), or tuples of
      FlagHolder and the new flag value.
    **kwargs: Keyword args mapping flag name to new flag value.

  Returns:
    An instance of flag_overrider_cls to be used as a decorator or context
    manager, or, when args is a single callable, that callable wrapped so
    flags are saved and restored around each call.

  Raises:
    ValueError: If a callable is combined with keyword args, if a positional
      arg is not a (FlagHolder, value) pair, or if a flag is set twice.
    TypeError: If the single callable passed is a class.
  """
  if not args:
    return flag_overrider_cls(**kwargs)
  # args can be [func] if used as `@flagsaver` instead of `@flagsaver(...)`
  if len(args) == 1 and callable(args[0]):
    if kwargs:
      raise ValueError(
          "It's invalid to specify both positional and keyword parameters.")
    func = args[0]
    if inspect.isclass(func):
      raise TypeError('@flagsaver.flagsaver cannot be applied to a class.')
    return _wrap(flag_overrider_cls, func, {})
  # args can be a list of (FlagHolder, value) pairs.
  # In which case they augment any specified kwargs.
  for arg in args:
    if not isinstance(arg, tuple) or len(arg) != 2:
      raise ValueError('Expected (FlagHolder, value) pair, found %r' % (arg,))
    holder, value = arg
    if not isinstance(holder, flags.FlagHolder):
      raise ValueError('Expected (FlagHolder, value) pair, found %r' % (arg,))
    # kwargs is updated in place, so this also catches a holder repeated in args.
    if holder.name in kwargs:
      raise ValueError('Cannot set --%s multiple times' % holder.name)
    kwargs[holder.name] = value
  return flag_overrider_cls(**kwargs)


def save_flag_values(
    flag_values: flags.FlagValues = FLAGS) -> Mapping[str, Mapping[str, Any]]:
  """Returns copy of flag values as a dict.

  Args:
    flag_values: FlagValues, the FlagValues instance with which the flag will be
      saved. This should almost never need to be overridden.

  Returns:
    Dictionary mapping keys to values. Keys are flag names, values are
    corresponding ``__dict__`` members. E.g. ``{'key': value_dict, ...}``.
  """
  return {name: _copy_flag_dict(flag_values[name]) for name in flag_values}


def restore_flag_values(saved_flag_values: Mapping[str, Mapping[str, Any]],
                        flag_values: flags.FlagValues = FLAGS):
  """Restores flag values based on the dictionary of flag values.

  Only flags that currently exist in flag_values are visited. Flags absent
  from saved_flag_values (added after saving) are deleted. Saved flags that
  no longer exist are not recreated.

  The saved mapping itself is never installed on a flag; a fresh copy is used
  instead. That way the same snapshot can be restored more than once, and later
  flag mutations cannot leak back into it.

  Args:
    saved_flag_values: {'flag_name': value_dict, ...}
    flag_values: FlagValues, the FlagValues instance from which the flag will be
      restored. This should almost never need to be overridden.
  """
  new_flag_names = list(flag_values)
  for name in new_flag_names:
    saved = saved_flag_values.get(name)
    if saved is None:
      # If __dict__ was not saved delete "new" flag.
      delattr(flag_values, name)
      continue
    flag = flag_values[name]
    if flag.value != saved['_value']:
      flag.value = saved['_value']  # Ensure C++ value is set.
    restored = dict(saved)
    # Copy the validators list as well so validators registered later are not
    # appended to the saved snapshot.
    if 'validators' in restored:
      restored['validators'] = list(restored['validators'])
    flag.__dict__ = restored


@overload
def _wrap(flag_overrider_cls: Type['_FlagOverrider'], func: _CallableT,
          overrides: Mapping[str, Any]) -> _CallableT:
  ...
@overload
def _wrap(flag_overrider_cls: Type['_ParsingFlagOverrider'], func: _CallableT,
          overrides: Mapping[str, Union[str, Sequence[str]]]) -> _CallableT:
  ...


def _wrap(flag_overrider_cls, func, overrides):
  """Creates a wrapper function that saves/restores flag values.

  Args:
    flag_overrider_cls: The class that will be used as a context manager.
    func: This will be called between saving flags and restoring flags.
    overrides: Flag names mapped to their values. These flags will be set after
      saving the original flag state. The type of the values depends on if
      _FlagOverrider or _ParsingFlagOverrider was specified.

  Returns:
    A wrapped version of func.
  """

  @functools.wraps(func)
  def _flagsaver_wrapper(*args, **kwargs):
    """Wrapper function that saves and restores flags."""
    # A fresh overrider per call, so each invocation saves its own snapshot.
    with flag_overrider_cls(**overrides):
      return func(*args, **kwargs)

  return _flagsaver_wrapper


class _FlagOverrider(object):
  """Overrides flags for the duration of the decorated function call.

  It also restores all original values of flags after decorated method
  completes.
  """

  def __init__(self, **overrides: Any):
    # Flag name -> native Python value to assign on __enter__.
    self._overrides = overrides
    # Snapshot taken on __enter__ and restored on __exit__.
    self._saved_flag_values = None

  def __call__(self, func: _CallableT) -> _CallableT:
    """Decorates func so flags are overridden during each call to it."""
    if inspect.isclass(func):
      raise TypeError('flagsaver cannot be applied to a class.')
    return _wrap(self.__class__, func, self._overrides)

  def __enter__(self):
    self._saved_flag_values = save_flag_values(FLAGS)
    try:
      FLAGS._set_attributes(**self._overrides)
    except BaseException:
      # It may fail because of flag validators. __exit__ is not called when
      # __enter__ raises, so restore here before re-raising.
      restore_flag_values(self._saved_flag_values, FLAGS)
      raise

  def __exit__(self, exc_type, exc_value, traceback):
    restore_flag_values(self._saved_flag_values, FLAGS)


class _ParsingFlagOverrider(_FlagOverrider):
  """Context manager for overriding flags.

  Simulates command line parsing.

  This is similar to _FlagOverrider except that all **overrides should be
  strings or sequences of strings, and when context is entered this class calls
  .parse(value)

  This results in the flags having .present set properly.
  """

  def __init__(self, **overrides: Union[str, Sequence[str]]):
    for flag_name, new_value in overrides.items():
      if isinstance(new_value, str):
        continue
      if (isinstance(new_value, collections.abc.Sequence) and
          all(isinstance(single_value, str) for single_value in new_value)):
        continue
      raise TypeError(
          f'flagsaver.as_parsed() cannot parse {flag_name}. Expected a single '
          f'string or sequence of strings but {type(new_value)} was provided.')
    super().__init__(**overrides)

  def __enter__(self):
    self._saved_flag_values = save_flag_values(FLAGS)
    try:
      for flag_name, unparsed_value in self._overrides.items():
        # Check existence explicitly rather than catching KeyError around the
        # whole block: a KeyError raised inside a parser or validator must not
        # be misreported as an unknown flag.
        if flag_name not in FLAGS:
          raise flags.UnrecognizedFlagError(flag_name)
        # LINT.IfChange(flag_override_parsing)
        FLAGS[flag_name].parse(unparsed_value)
        FLAGS[flag_name].using_default_value = False
        # LINT.ThenChange()

      # Perform the validation on all modified flags. This is something that
      # FLAGS._set_attributes() does for you in _FlagOverrider.
      for flag_name in self._overrides:
        FLAGS._assert_validators(FLAGS[flag_name].validators)

    except BaseException:
      # It may fail because of unknown flags, flag validators or general
      # parsing issues. __exit__ is not called when __enter__ raises, so
      # restore here before re-raising.
      restore_flag_values(self._saved_flag_values, FLAGS)
      raise


def _copy_flag_dict(flag: flags.Flag) -> Mapping[str, Any]:
  """Returns a copy of the flag object's ``__dict__``.

  It's mostly a shallow copy of the ``__dict__``, except it also does a shallow
  copy of the validator list.

  Args:
    flag: flags.Flag, the flag to copy.

  Returns:
    A copy of the flag object's ``__dict__``.
  """
  copy = flag.__dict__.copy()
  copy['_value'] = flag.value  # Ensure correct restore for C++ flags.
  copy['validators'] = list(flag.validators)
  return copy