# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Facilities for creating multiple test combinations.

Here is a simple example for testing various optimizers in Eager and Graph:

class AdditionExample(test.TestCase, parameterized.TestCase):
  @combinations.generate(
     combinations.combine(mode=["graph", "eager"],
                          optimizer=[AdamOptimizer(),
                                     GradientDescentOptimizer()]))
  def testOptimizer(self, optimizer):
    ... f(optimizer)...

This will run `testOptimizer` 4 times with the specified optimizers: 2 in
Eager and 2 in Graph mode. (Here `mode` is consumed by a `TestCombination`,
such as the `EagerGraphCombination` example in `OptionalParameter`, which is
why the test does not need to accept it.)
The test is going to accept the same parameters as the ones used in `combine()`.
The parameters need to match by name between the `combine()` call and the test
signature. It is necessary to accept all parameters. See `OptionalParameter`
for a way to implement optional parameters.

The `combine()` function is available for creating a cross product of various
options. The `times()` function exists for creating a product of N
`combine()`-ed results.

The execution of generated tests can be customized in a number of ways:
-  The test can be skipped if it is not running in the correct environment.
-  The arguments that are passed to the test can be additionally transformed.
-  The test can be run with specific Python context managers.
These behaviors can be customized by providing instances of `TestCombination` to
`generate()`.
"""

from collections import OrderedDict
import contextlib
import re
import types
import unittest

from absl.testing import parameterized

from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


@tf_export("__internal__.test.combinations.TestCombination", v1=[])
class TestCombination:
  """Customize the behavior of `generate()` and the tests that it executes.

  Here is the sequence of steps for executing a test combination:
    1. The test combination is evaluated for whether it should be executed in
       the given environment by calling `should_execute_combination`.
    2. If the test combination is going to be executed, then the arguments for
       all combined parameters are validated. Some arguments can be handled in
       a special way. This is achieved by implementing that logic in
       `ParameterModifier` instances that are returned from
       `parameter_modifiers`.
    3. Before executing the test, `context_managers` are installed
       around it.
  """

  def should_execute_combination(self, kwargs):
    """Indicates whether the combination of test arguments should be executed.

    If the environment doesn't satisfy the dependencies of the test
    combination, then it can be skipped.

    Args:
      kwargs: Arguments that are passed to the test combination. This is a
        copy, so implementations may modify it freely.

    Returns:
      A tuple of a boolean and an optional string. The boolean False indicates
      that the test should be skipped. The string would indicate a textual
      description of the reason. If the test is going to be executed, then
      this method returns `None` instead of the string.
    """
    del kwargs
    return (True, None)

  def parameter_modifiers(self):
    """Returns `ParameterModifier` instances that customize the arguments."""
    return []

  def context_managers(self, kwargs):
    """Return context managers for running the test combination.

    The test combination will run under all context managers that all
    `TestCombination` instances return.

    Args:
      kwargs: Arguments and their values that are passed to the test
        combination. This is a copy, so implementations may modify it freely.

    Returns:
      A list of instantiated context managers.
    """
    del kwargs
    return []


@tf_export("__internal__.test.combinations.ParameterModifier", v1=[])
class ParameterModifier:
  """Customizes the behavior of a particular parameter.

  Users should override `modified_arguments()` to modify the parameter they
  want, e.g., change the value of a certain parameter or filter it from the
  params passed to the test case.

  See the sample usage below. It will change any negative parameters to zero
  before they get passed to the test case.
  ```
  class NonNegativeParameterModifier(ParameterModifier):

    def modified_arguments(self, kwargs, requested_parameters):
      updates = {}
      for name, value in kwargs.items():
        if value < 0:
          updates[name] = 0
      return updates
  ```
  """

  # Sentinel value: returning it for a key from `modified_arguments()` removes
  # that key from the arguments passed to the test.
  DO_NOT_PASS_TO_THE_TEST = object()

  def __init__(self, parameter_name=None):
    """Construct a parameter modifier that may be specific to a parameter.

    Args:
      parameter_name: A `ParameterModifier` instance may operate on a class of
        parameters or on a parameter with a particular name. Only
        `ParameterModifier` instances that are of a unique type or were
        initialized with a unique `parameter_name` will be executed.
        See `__eq__` and `__hash__`.
    """
    self._parameter_name = parameter_name

  def modified_arguments(self, kwargs, requested_parameters):
    """Replace user-provided arguments before they are passed to a test.

    This makes it possible to adjust user-provided arguments before passing
    them to the test method.

    Args:
      kwargs: The combined arguments for the test.
      requested_parameters: The set of parameters that are defined in the
        signature of the test method.

    Returns:
      A dictionary with updates to `kwargs`. Keys with values set to
      `ParameterModifier.DO_NOT_PASS_TO_THE_TEST` are going to be deleted and
      not passed to the test.
    """
    del kwargs, requested_parameters
    return {}

  def __eq__(self, other):
    """Compare `ParameterModifier` by type and `parameter_name`."""
    if self is other:
      return True
    elif type(self) is type(other):
      return self._parameter_name == other._parameter_name
    else:
      return False

  def __ne__(self, other):
    return not self.__eq__(other)

  def __hash__(self):
    """Hash consistently with `__eq__`.

    Uses `parameter_name` when it is truthy, otherwise the class identity.
    Instances that compare equal share a type and a name, so they always hash
    equally.
    """
    if self._parameter_name:
      return hash(self._parameter_name)
    else:
      return id(self.__class__)


@tf_export("__internal__.test.combinations.OptionalParameter", v1=[])
class OptionalParameter(ParameterModifier):
  """A parameter that is optional in `combine()` and in the test signature.

  `OptionalParameter` is usually used with `TestCombination` in the
  `parameter_modifiers()`. It allows `TestCombination` to skip certain
  parameters when passing them to `combine()`, since the `TestCombination` might
  consume the param and create some context based on the value it gets.

  See the sample usage below:

  ```
  class EagerGraphCombination(TestCombination):

    def context_managers(self, kwargs):
      mode = kwargs.pop("mode", None)
      if mode is None:
        return []
      elif mode == "eager":
        return [context.eager_mode()]
      elif mode == "graph":
        return [ops.Graph().as_default(), context.graph_mode()]
      else:
        raise ValueError(
            "'mode' has to be either 'eager' or 'graph', got {}".format(mode))

    def parameter_modifiers(self):
      return [test_combinations.OptionalParameter("mode")]
  ```

  When the test case is generated, the param "mode" will not be passed to the
  test method, since it is consumed by the `EagerGraphCombination`.
  """

  def modified_arguments(self, kwargs, requested_parameters):
    # Pass the argument through only if the test signature asks for it.
    if self._parameter_name in requested_parameters:
      return {}
    else:
      return {self._parameter_name: ParameterModifier.DO_NOT_PASS_TO_THE_TEST}


def generate(combinations, test_combinations=()):
  """A decorator for generating combinations of a test method or a test class.

  Parameters of the test method must match by name to get the corresponding
  value of the combination. Tests must accept all parameters that are passed
  other than the ones that are `OptionalParameter`.

  Args:
    combinations: a list of dictionaries created using combine() and times().
    test_combinations: a tuple of `TestCombination` instances that customize
      the execution of generated tests.

  Returns:
    a decorator that will cause the test method or the test class to be run
    under the specified conditions.

  Raises:
    ValueError: if any parameters were not accepted by the test method. This
      is raised when a generated test runs, not when the decorator is applied.
  """
  def decorator(test_method_or_class):
    """The decorator to be returned."""

    # Generate good test names that can be used with --test_filter.
    named_combinations = []
    for combination in combinations:
      # We use OrderedDicts in `combine()` and `times()` to ensure stable
      # order of keys in each dictionary.
      assert isinstance(combination, OrderedDict)
      suffix = "".join([
          "_{}_{}".format("".join(filter(str.isalnum, key)),
                          "".join(filter(str.isalnum, _get_name(value, i))))
          for i, (key, value) in enumerate(combination.items())
      ])
      named_combinations.append(
          OrderedDict(
              list(combination.items()) +
              [("testcase_name", "_test{}".format(suffix))]))

    if isinstance(test_method_or_class, type):
      class_object = test_method_or_class
      class_object._test_method_ids = test_method_ids = {}
      # Iterate over a snapshot because the class dict is mutated below.
      for attr_name, attr_value in class_object.__dict__.copy().items():
        if (attr_name.startswith(unittest.TestLoader.testMethodPrefix) and
            isinstance(attr_value, types.FunctionType)):
          # Replace each test method with one generated method per
          # combination.
          delattr(class_object, attr_name)
          methods = {}
          parameterized._update_class_dict_for_param_test_case(
              class_object.__name__, methods, test_method_ids, attr_name,
              parameterized._ParameterizedTestIter(
                  _augment_with_special_arguments(
                      attr_value, test_combinations=test_combinations),
                  named_combinations, parameterized._NAMED, attr_name))
          for method_name, method in methods.items():
            setattr(class_object, method_name, method)

      return class_object
    else:
      test_method = _augment_with_special_arguments(
          test_method_or_class, test_combinations=test_combinations)
      return parameterized.named_parameters(*named_combinations)(test_method)

  return decorator


def _augment_with_special_arguments(test_method, test_combinations):
  """Wraps `test_method` so that `test_combinations` control how it runs.

  The returned function receives the combined arguments as keyword arguments
  and:
    1. skips the test via `self.skipTest` if any `TestCombination` reports
       that the combination should not execute;
    2. applies the `ParameterModifier`s of all `TestCombination`s and checks
       that the resulting arguments match the test method's signature;
    3. calls `test_method` inside the `context_managers` of all
       `TestCombination`s.

  Args:
    test_method: the test method to wrap.
    test_combinations: an iterable of `TestCombination` instances.

  Returns:
    A function with the signature `(self, **kwargs)`.
  """
  def decorated(self, **kwargs):
    """A wrapped test method that can treat some arguments in a special way."""
    original_kwargs = kwargs.copy()

    # Skip combinations that are going to be executed in a different testing
    # environment.
    reasons_to_skip = []
    for combination in test_combinations:
      should_execute, reason = combination.should_execute_combination(
          original_kwargs.copy())
      if not should_execute:
        if reason is None:
          # Concatenating None would raise a TypeError and hide the skip.
          reason = "skipped by {} (no reason given)".format(
              type(combination).__name__)
        reasons_to_skip.append(" - " + str(reason))

    if reasons_to_skip:
      self.skipTest("\n".join(reasons_to_skip))

    # De-duplicate modifiers (see `ParameterModifier.__eq__`/`__hash__`) while
    # keeping their first-seen order. A `set` would apply them in hash order,
    # which can vary from run to run.
    customized_parameters = list(dict.fromkeys(
        modifier
        for combination in test_combinations
        for modifier in combination.parameter_modifiers()))

    # The function for running the test under the total set of
    # `context_managers`:
    def execute_test_method():
      requested_parameters = tf_inspect.getfullargspec(test_method).args
      for customized_parameter in customized_parameters:
        for argument, value in customized_parameter.modified_arguments(
            original_kwargs.copy(), requested_parameters).items():
          if value is ParameterModifier.DO_NOT_PASS_TO_THE_TEST:
            kwargs.pop(argument, None)
          else:
            kwargs[argument] = value

      # `self` is supplied here rather than through the combination.
      available_arguments = set(kwargs) | {"self"}
      requested_set = set(requested_parameters)

      omitted_arguments = requested_set - available_arguments
      if omitted_arguments:
        raise ValueError("The test requires parameters whose arguments "
                         "were not passed: {} .".format(omitted_arguments))
      unexpected_arguments = available_arguments - requested_set
      if unexpected_arguments:
        raise ValueError("The test does not take parameters that were passed "
                         ": {} .".format(unexpected_arguments))

      kwargs_to_pass = {
          parameter: self if parameter == "self" else kwargs[parameter]
          for parameter in requested_parameters
      }
      test_method(**kwargs_to_pass)

    # Collect all `context_managers` first, then install them before running
    # the test.
    context_managers = []
    for combination in test_combinations:
      context_managers.extend(
          combination.context_managers(original_kwargs.copy()))

    with contextlib.ExitStack() as context_stack:
      for manager in context_managers:
        context_stack.enter_context(manager)
      execute_test_method()

  return decorated


@tf_export("__internal__.test.combinations.combine", v1=[])
def combine(**kwargs):
  """Generate combinations based on its keyword arguments.

  Two sets of returned combinations can be concatenated using +. Their product
  can be computed using `times()`.

  Args:
    **kwargs: keyword arguments of form `option=[possibilities, ...]`
         or `option=the_only_possibility`. Only `list` values are expanded;
         any other value (including a tuple) is a single possibility.

  Returns:
    a list of dictionaries for each combination. Keys in the dictionaries are
    the keyword argument names. Each key has one value - one of the
    corresponding keyword argument values.
  """
  if not kwargs:
    return [OrderedDict()]

  # Process keys in sorted order so that every resulting dictionary, and
  # therefore every generated test name, has a stable key order.
  sorted_items = sorted(kwargs.items(), key=lambda item: item[0])
  key, values = sorted_items[0]
  rest_combined = combine(**dict(sorted_items[1:]))

  if not isinstance(values, list):
    values = [values]

  # `key` sorts before every key in `rest_combined`, whose dictionaries are
  # already sorted, so prepending it keeps each result sorted by key.
  return [
      OrderedDict([(key, v)] + list(combined.items()))
      for v in values
      for combined in rest_combined
  ]


@tf_export("__internal__.test.combinations.times", v1=[])
def times(*combined):
  """Generate a product of N sets of combinations.

  times(combine(a=[1,2]), combine(b=[3,4])) == combine(a=[1,2], b=[3,4])

  Args:
    *combined: N lists of dictionaries that specify combinations.

  Returns:
    a list of dictionaries for each combination.

  Raises:
    ValueError: if some of the inputs have overlapping keys.
  """
  assert combined

  if len(combined) == 1:
    return combined[0]

  first = combined[0]
  rest_combined = times(*combined[1:])

  combined_results = []
  for a in first:
    for b in rest_combined:
      if not set(a.keys()).isdisjoint(b.keys()):
        raise ValueError("Keys need to not overlap: {} vs {}".format(
            a.keys(), b.keys()))

      combined_results.append(OrderedDict(list(a.items()) + list(b.items())))
  return combined_results


@tf_export("__internal__.test.combinations.NamedObject", v1=[])
class NamedObject:
  """A class that translates an object into a good test name.

  Attribute access, calls and iteration are forwarded to the wrapped object.
  `repr()` (and therefore `str()`, which is used to build test names in
  `generate()`) returns `name`.
  """

  def __init__(self, name, obj):
    self._name = name
    self._obj = obj

  def __getattr__(self, name):
    # `__getattr__` only runs when normal lookup fails. Reading `_obj` through
    # `__dict__` avoids infinite recursion on instances created without
    # `__init__` (e.g. by `copy` or `pickle`), where `_obj` is not yet set.
    try:
      obj = self.__dict__["_obj"]
    except KeyError:
      raise AttributeError(name) from None
    return getattr(obj, name)

  def __call__(self, *args, **kwargs):
    return self._obj(*args, **kwargs)

  def __iter__(self):
    return iter(self._obj)

  def __repr__(self):
    return self._name


# Hexadecimal numbers in `str(value)` (for example the address in a default
# object repr such as "<Foo object at 0x7f...>") are not stable across runs.
_HEX_NUMBER_RE = re.compile(r"0[xX][0-9a-fA-F]+")


def _get_name(value, index):
  """Returns `str(value)` with every hexadecimal number replaced by `index`.

  Args:
    value: a value from a combination.
    index: the position of the value's key within its combination.

  Returns:
    A string suitable for building a reproducible test name.
  """
  return _HEX_NUMBER_RE.sub(str(index), str(value))