xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/test_combinations.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Facilities for creating multiple test combinations.
16
17Here is a simple example for testing various optimizers in Eager and Graph:
18
19class AdditionExample(test.TestCase, parameterized.TestCase):
20  @combinations.generate(
21     combinations.combine(mode=["graph", "eager"],
22                          optimizer=[AdamOptimizer(),
23                                     GradientDescentOptimizer()]))
24  def testOptimizer(self, optimizer):
25    ... f(optimizer)...
26
27This will run `testOptimizer` 4 times with the specified optimizers: 2 in
28Eager and 2 in Graph mode.
29The test is going to accept the same parameters as the ones used in `combine()`.
30The parameters need to match by name between the `combine()` call and the test
31signature.  It is necessary to accept all parameters. See `OptionalParameter`
32for a way to implement optional parameters.
33
34`combine()` function is available for creating a cross product of various
35options.  `times()` function exists for creating a product of N `combine()`-ed
36results.
37
38The execution of generated tests can be customized in a number of ways:
39-  The test can be skipped if it is not running in the correct environment.
40-  The arguments that are passed to the test can be additionally transformed.
41-  The test can be run with specific Python context managers.
42These behaviors can be customized by providing instances of `TestCombination` to
43`generate()`.
44"""
45
46from collections import OrderedDict
47import contextlib
48import re
49import types
50import unittest
51
52from absl.testing import parameterized
53
54from tensorflow.python.util import tf_inspect
55from tensorflow.python.util.tf_export import tf_export
56
57
@tf_export("__internal__.test.combinations.TestCombination", v1=[])
class TestCombination:
  """Customize the behavior of `generate()` and the tests that it executes.

  Executing a test combination proceeds in the following steps:
    1. `should_execute_combination` is called to decide whether the
       combination should run in the current environment.
    2. If it runs, the arguments for all combined parameters are validated.
       Arguments that need special treatment are handled by the
       `ParameterModifier` instances returned from `parameter_modifiers`.
    3. The test is run inside the context managers returned from
       `context_managers`.

  The base class is a no-op: it always executes the test, modifies no
  parameters and installs no context managers.
  """

  def should_execute_combination(self, kwargs):
    """Indicates whether the combination of test arguments should be executed.

    Override this to skip combinations whose dependencies are not satisfied by
    the current environment.

    Args:
      kwargs:  Arguments that are passed to the test combination.

    Returns:
      A tuple `(should_execute, reason)`.  `should_execute` is False when the
      test should be skipped, in which case `reason` is a string describing
      why.  When the test is going to be executed, `reason` is `None`.
    """
    del kwargs  # Not needed: every combination runs by default.
    return True, None

  def parameter_modifiers(self):
    """Returns `ParameterModifier` instances that customize the arguments.

    The base implementation returns an empty list.
    """
    return []

  def context_managers(self, kwargs):
    """Return context managers for running the test combination.

    The test combination runs under the context managers returned by all of
    the `TestCombination` instances that were given to `generate()`.

    Args:
      kwargs:  Arguments and their values that are passed to the test
        combination.

    Returns:
      A list of instantiated context managers; empty by default.
    """
    del kwargs  # Not needed: no context managers by default.
    return []
110
111
@tf_export("__internal__.test.combinations.ParameterModifier", v1=[])
class ParameterModifier:
  """Customizes the behavior of a particular parameter.

  Subclasses override `modified_arguments()` to adjust the arguments that a
  test receives, e.g. to change the value of a parameter or to drop it from
  the arguments passed to the test method.

  For example, the following modifier replaces every negative argument with
  zero before it reaches the test case:
  ```
  class NonNegativeParameterModifier(ParameterModifier):

    def modified_arguments(self, kwargs, requested_parameters):
      updates = {}
      for name, value in kwargs.items():
        if value < 0:
          updates[name] = 0
      return updates
  ```
  """

  # Sentinel: a key mapped to this object in the dictionary returned by
  # `modified_arguments()` is removed from the arguments passed to the test.
  DO_NOT_PASS_TO_THE_TEST = object()

  def __init__(self, parameter_name=None):
    """Construct a parameter modifier that may be specific to a parameter.

    Args:
      parameter_name:  The name of the parameter this modifier operates on, or
        `None` for a modifier that operates on a whole class of parameters.
        Modifiers compare equal when they have the same type and the same
        `parameter_name`, and only one modifier out of a group of equal ones
        is executed.  See `__eq__` and `__hash__`.
    """
    self._parameter_name = parameter_name

  def modified_arguments(self, kwargs, requested_parameters):
    """Replace user-provided arguments before they are passed to a test.

    Args:
      kwargs:  The combined arguments for the test.
      requested_parameters: The parameter names that appear in the signature
        of the test method.

    Returns:
      A dictionary of updates to apply to `kwargs`.  Keys whose value is
      `ParameterModifier.DO_NOT_PASS_TO_THE_TEST` are removed instead and not
      passed to the test.  The base implementation returns no updates.
    """
    del kwargs, requested_parameters  # Unused by the default implementation.
    return {}

  def __eq__(self, other):
    """Modifiers are equal when they share both type and `parameter_name`."""
    if self is other:
      return True
    return (type(self) is type(other) and
            self._parameter_name == other._parameter_name)

  def __ne__(self, other):
    return not self.__eq__(other)

  def __hash__(self):
    """Hashes consistently with `__eq__`: equal modifiers hash equally.

    Named modifiers hash by their parameter name; unnamed ones hash by the
    identity of their class.
    """
    if not self._parameter_name:
      return id(self.__class__)
    return hash(self._parameter_name)
185
186
@tf_export("__internal__.test.combinations.OptionalParameter", v1=[])
class OptionalParameter(ParameterModifier):
  """A parameter that is optional in `combine()` and in the test signature.

  `OptionalParameter` is typically returned from a `TestCombination`'s
  `parameter_modifiers()`.  It lets the `TestCombination` consume a parameter
  given to `combine()` (for instance to create a context based on its value)
  without requiring every test method to accept that parameter.

  Sample usage:

  ```
  class EagerGraphCombination(TestCombination):

    def context_managers(self, kwargs):
      mode = kwargs.pop("mode", None)
      if mode is None:
        return []
      elif mode == "eager":
        return [context.eager_mode()]
      elif mode == "graph":
        return [ops.Graph().as_default(), context.graph_mode()]
      else:
        raise ValueError(
            "'mode' has to be either 'eager' or 'graph', got {}".format(mode))

    def parameter_modifiers(self):
      return [test_combinations.OptionalParameter("mode")]
  ```

  When the test case is generated, the param "mode" is consumed by
  `EagerGraphCombination` and is only passed to the test method if the test
  method declares a `mode` parameter itself.
  """

  def modified_arguments(self, kwargs, requested_parameters):
    """Drops the parameter unless the test method asks for it.

    Args:
      kwargs:  The combined arguments for the test (not inspected).
      requested_parameters: The parameter names in the test method signature.

    Returns:
      An empty dictionary when the test method declares the parameter;
      otherwise a dictionary that maps the parameter name to
      `ParameterModifier.DO_NOT_PASS_TO_THE_TEST`.
    """
    if self._parameter_name not in requested_parameters:
      return {self._parameter_name: ParameterModifier.DO_NOT_PASS_TO_THE_TEST}
    return {}
226
227
def generate(combinations, test_combinations=()):
  """A decorator for generating combinations of a test method or a test class.

  Every combination becomes a separately named test case.  The name is built
  from the combination's keys and values (alphanumeric characters only), so
  individual cases can be selected with --test_filter.  Parameters of the test
  method must match by name to get the corresponding value of the
  combination.  Tests must accept all parameters that are passed other than
  the ones that are `OptionalParameter`.

  Args:
    combinations: a list of dictionaries created using combine() and times().
      Each element must be an `OrderedDict` (checked with `assert`).
    test_combinations: a tuple of `TestCombination` instances that customize
      the execution of generated tests.

  Returns:
    a decorator that will cause the test method or the test class to be run
    under the specified conditions.

  Raises:
    ValueError: when a generated test runs and its arguments are not accepted
      by the test method.
  """

  def _alnum(text):
    """Keeps only the alphanumeric characters of `text`."""
    return "".join(filter(str.isalnum, text))

  def decorator(test_method_or_class):
    """The decorator to be returned."""
    # Attach a descriptive `testcase_name` to every combination so that the
    # generated tests can be selected with --test_filter.
    named_combinations = []
    for combination in combinations:
      # `combine()` and `times()` produce OrderedDicts, which keeps the key
      # order (and therefore the generated names) stable.
      assert isinstance(combination, OrderedDict)
      suffix = "".join(
          "_{}_{}".format(_alnum(key), _alnum(_get_name(value, position)))
          for position, (key, value) in enumerate(combination.items()))
      named = OrderedDict(combination)
      named["testcase_name"] = "_test{}".format(suffix)
      named_combinations.append(named)

    if not isinstance(test_method_or_class, type):
      # A single test method: wrap it and let absl parameterize it.
      wrapped_method = _augment_with_special_arguments(
          test_method_or_class, test_combinations=test_combinations)
      return parameterized.named_parameters(*named_combinations)(wrapped_method)

    # A whole test class: replace every `test*` function defined directly on
    # the class with its parameterized variants.
    class_object = test_method_or_class
    test_method_ids = {}
    class_object._test_method_ids = test_method_ids
    for attr_name, attr_value in list(class_object.__dict__.items()):
      is_test_function = (
          attr_name.startswith(unittest.TestLoader.testMethodPrefix) and
          isinstance(attr_value, types.FunctionType))
      if not is_test_function:
        continue
      delattr(class_object, attr_name)
      generated_methods = {}
      parameterized._update_class_dict_for_param_test_case(
          class_object.__name__, generated_methods, test_method_ids, attr_name,
          parameterized._ParameterizedTestIter(
              _augment_with_special_arguments(
                  attr_value, test_combinations=test_combinations),
              named_combinations, parameterized._NAMED, attr_name))
      for method_name, method in generated_methods.items():
        setattr(class_object, method_name, method)
    return class_object

  return decorator
290
291
def _augment_with_special_arguments(test_method, test_combinations):
  """Wraps `test_method` so that `test_combinations` customize every run.

  On every invocation the returned function:
    1. skips the test if any `TestCombination` reports, through
       `should_execute_combination`, that it should not run here;
    2. applies the `ParameterModifier`s from every combination's
       `parameter_modifiers()` to the keyword arguments;
    3. checks that the resulting arguments exactly match the parameters of
       `test_method`;
    4. calls `test_method` inside all context managers returned by the
       combinations' `context_managers()`.

  Args:
    test_method: The undecorated test method.  Its signature determines which
      arguments it receives.
    test_combinations: An iterable of `TestCombination` instances.  It is
      iterated on every call of the returned function.

  Returns:
    A function `decorated(self, **kwargs)` that runs `test_method` as
    described above.
  """

  def decorated(self, **kwargs):
    """A wrapped test method that can treat some arguments in a special way.

    Raises:
      ValueError: if, after the parameter modifiers are applied, the arguments
        do not match the parameters of the wrapped test method.
    """
    original_kwargs = kwargs.copy()

    # Skip combinations that are going to be executed in a different testing
    # environment.
    reasons_to_skip = []
    for combination in test_combinations:
      should_execute, reason = combination.should_execute_combination(
          original_kwargs.copy())
      if not should_execute:
        if reason is None:
          # The reason is documented as optional; concatenating `None` would
          # raise a TypeError instead of skipping the test.
          reason = "{} did not give a reason".format(
              type(combination).__name__)
        reasons_to_skip.append(" - {}".format(reason))

    if reasons_to_skip:
      self.skipTest("\n".join(reasons_to_skip))

    # Collect the parameter modifiers.  The set keeps a single modifier out of
    # each group of equal ones (same type and same parameter name).
    customized_parameters = set()
    for combination in test_combinations:
      customized_parameters.update(combination.parameter_modifiers())

    def execute_test_method():
      """Applies the parameter modifiers, validates and runs the test."""
      requested_parameters = tf_inspect.getfullargspec(test_method).args
      for customized_parameter in customized_parameters:
        updates = customized_parameter.modified_arguments(
            original_kwargs.copy(), requested_parameters)
        for argument, value in updates.items():
          if value is ParameterModifier.DO_NOT_PASS_TO_THE_TEST:
            kwargs.pop(argument, None)
          else:
            kwargs[argument] = value

      passed_names = set(kwargs) | {"self"}
      requested_names = set(requested_parameters)
      # Parameters of the test method that did not receive an argument.
      omitted_arguments = requested_names - passed_names
      if omitted_arguments:
        raise ValueError("The test requires parameters whose arguments "
                         "were not passed: {} .".format(omitted_arguments))
      # Arguments that the test method does not declare as parameters.
      unexpected_arguments = passed_names - requested_names
      if unexpected_arguments:
        raise ValueError("The test does not take parameters that were passed "
                         ": {} .".format(unexpected_arguments))

      kwargs_to_pass = {
          parameter: self if parameter == "self" else kwargs[parameter]
          for parameter in requested_parameters
      }
      test_method(**kwargs_to_pass)

    # Gather all context managers before entering any of them.
    context_managers = []
    for combination in test_combinations:
      context_managers.extend(
          combination.context_managers(original_kwargs.copy()))

    with contextlib.ExitStack() as context_stack:
      for manager in context_managers:
        context_stack.enter_context(manager)
      execute_test_method()

  return decorated
363
364
@tf_export("__internal__.test.combinations.combine", v1=[])
def combine(**kwargs):
  """Generate combinations based on its keyword arguments.

  Two sets of returned combinations can be concatenated using +.  Their product
  can be computed using `times()`.

  The keys of every returned dictionary appear in sorted order.  The
  combinations are ordered so that the alphabetically first keyword varies
  slowest, e.g. `combine(b=[3, 4], a=[1, 2])` yields (a=1, b=3), (a=1, b=4),
  (a=2, b=3), (a=2, b=4).

  Args:
    **kwargs: keyword arguments of form `option=[possibilities, ...]`
         or `option=the_only_possibility`.  Only `list` values are expanded;
         any other value (a tuple included) is a single possibility.

  Returns:
    a list of dictionaries for each combination. Keys in the dictionaries are
    the keyword argument names.  Each key has one value - one of the
    corresponding keyword argument values.  With no keyword arguments, a list
    holding one empty dictionary is returned.
  """
  # Build the cross product starting from the alphabetically last key.
  # Prepending each earlier key keeps the keys sorted and makes the first key
  # vary slowest.
  results = [OrderedDict()]
  for key in sorted(kwargs, reverse=True):
    options = kwargs[key]
    if not isinstance(options, list):
      options = [options]
    results = [
        OrderedDict([(key, option)] + list(partial.items()))
        for option in options
        for partial in results
    ]
  return results
401
402
@tf_export("__internal__.test.combinations.times", v1=[])
def times(*combined):
  """Generate a product of N sets of combinations.

  times(combine(a=[1,2]), combine(b=[3,4])) == combine(a=[1,2], b=[3,4])

  Within each resulting dictionary, keys from earlier arguments come first.
  Combinations from the first argument vary slowest.

  Args:
    *combined: N lists of dictionaries that specify combinations.

  Returns:
    a list of dictionaries for each combination.  With a single argument, that
    argument is returned as-is.

  Raises:
    ValueError: if no inputs are given, or if some of the inputs have
      overlapping keys.
  """
  if not combined:
    # An explicit error instead of `assert`, which is stripped under `-O` and
    # would otherwise surface as an IndexError below.
    raise ValueError("times() requires at least one list of combinations.")

  if len(combined) == 1:
    return combined[0]

  first = combined[0]
  rest_combined = times(*combined[1:])

  combined_results = []
  for a in first:
    for b in rest_combined:
      if set(a.keys()).intersection(set(b.keys())):
        raise ValueError("Keys need to not overlap: {} vs {}".format(
            a.keys(), b.keys()))

      combined_results.append(OrderedDict(list(a.items()) + list(b.items())))
  return combined_results
435
436
@tf_export("__internal__.test.combinations.NamedObject", v1=[])
class NamedObject:
  """A class that translates an object into a good test name.

  Wraps `obj` so that `repr()` (and therefore `str()`) returns `name`, while
  attribute access, calls and iteration are forwarded to `obj`.
  """

  def __init__(self, name, obj):
    self._name = name
    self._obj = obj

  def __getattr__(self, name):
    # Python only calls `__getattr__` when normal attribute lookup fails.  If
    # `_name` or `_obj` is missing, the instance was created without running
    # `__init__` (for example via `__new__` while copying or unpickling), and
    # forwarding would evaluate `self._obj` again, recursing without bound.
    if name in ("_name", "_obj"):
      raise AttributeError(
          "{!r} object has no attribute {!r}".format(
              type(self).__name__, name))
    return getattr(self._obj, name)

  def __call__(self, *args, **kwargs):
    return self._obj(*args, **kwargs)

  def __iter__(self):
    return self._obj.__iter__()

  def __repr__(self):
    return self._name
456
457
458def _get_name(value, index):
459  return re.sub("0[xX][0-9a-fA-F]+", str(index), str(value))
460