# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""This module customizes `test_combinations` for `tf.distribute.Strategy`.
16
17Additionally it provides `generate()`, `combine()` and `times()` with
18`tf.distribute.Strategy` customizations as a default.
19"""

import collections
import copy
import re
import sys
import types
import unittest

from absl import app
import six


from tensorflow.python.client import session
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import combinations as framework_combinations
from tensorflow.python.framework import config
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_combinations as combinations_lib
from tensorflow.python.framework import test_util
from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


# TODO(rchao): Rename `distribution` parameter to `strategy` or
# `distribute_strategy` in all tests.
class DistributionParameter(combinations_lib.ParameterModifier):
  """Transforms arguments of type `NamedDistribution`.

  Converts all arguments of type `NamedDistribution` to the value of their
  `strategy` property.
  """

  def modified_arguments(self, kwargs, requested_parameters):
    # Get the parameter that indicates if we need to set the `_use_policy` flag
    # on the strategy object. This is a temporary flag for testing the variable
    # policy rollout.
    use_var_policy = kwargs.get("use_var_policy", None)
    distribution_arguments = {}
    for k, v in kwargs.items():
      if isinstance(v, NamedDistribution):
        strategy = v.strategy
        if use_var_policy:
          strategy.extended._use_var_policy = use_var_policy
        distribution_arguments[k] = strategy
    return distribution_arguments
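
# For example (a sketch; `one_device` is an illustrative entry):
#
# one_device = NamedDistribution(
#     "OneDeviceCPU", lambda: tf.distribute.OneDeviceStrategy("/cpu:0"))
#
# @combinations.generate(combinations.combine(distribution=[one_device]))
# def testXXX(self, distribution):
#   # `distribution` is the OneDeviceStrategy created by the lambda above,
#   # not the NamedDistribution wrapper.
#   ...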


class ClusterParameters(combinations_lib.ParameterModifier):
  """Adds cluster parameters if a `NamedDistribution` has them.

  It must run before `DistributionParameter`.
  """

  def modified_arguments(self, kwargs, requested_parameters):
    strategy = None
    for _, v in kwargs.items():
      if isinstance(v, NamedDistribution):
        if strategy is not None and _num_total_workers(v.has_chief,
                                                       v.num_workers) > 1:
          raise ValueError("Only one NamedDistribution is supported for "
                           "multi worker tests.")
        strategy = v

    if strategy:
      has_chief = strategy.has_chief
      num_workers = strategy.num_workers
      runner = strategy.runner
      share_gpu = strategy.share_gpu
      num_ps = strategy.num_ps
      if "has_chief" in kwargs and kwargs["has_chief"] != has_chief:
        raise ValueError(
            "both has_chief and strategy specified but are not compatible")
      if "num_workers" in kwargs and kwargs["num_workers"] != num_workers:
        raise ValueError(
            "both num_workers and strategy specified but are not compatible")
    else:
      has_chief = kwargs.get("has_chief", False)
      num_workers = kwargs.get("num_workers", 1)
      runner = kwargs.get("runner", None)
      share_gpu = kwargs.get("share_gpu", True)
      num_ps = kwargs.get("num_ps", 0)

    # Always set cluster parameters if they're requested, so that generate()
    # works when there's no strategy in the combinations.
    update = {}
    if "has_chief" in requested_parameters:
      update["has_chief"] = has_chief
    if "num_workers" in requested_parameters:
      update["num_workers"] = num_workers
    if "runner" in requested_parameters:
      update["runner"] = runner
    if "share_gpu" in requested_parameters:
      update["share_gpu"] = share_gpu
    if "num_ps" in requested_parameters:
      update["num_ps"] = num_ps
    return update
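
# For example (a sketch): if a NamedDistribution in the combination declares
# `has_chief=True, num_workers=2`, ClusterParameters fills in the test's
# requested cluster parameters from it, so generate() can spawn a matching
# chief + two-worker cluster:
#
# @combinations.generate(
#     combinations.combine(distribution=[some_multi_worker_strategy]))
# def testXXX(self, distribution, has_chief, num_workers):
#   ...  # has_chief == True, num_workers == 2.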


class DistributionCombination(combinations_lib.TestCombination):
  """Sets up distribution strategy for tests."""

  def should_execute_combination(self, kwargs):
    distributions = [
        v for v in kwargs.values() if isinstance(v, NamedDistribution)
    ]
    if test_util.is_xla_enabled() and any(d.no_xla for d in distributions):
      return (
          False,
          "n/a: skipping strategy combination with no_xla=True in XLA tests")
    return (True, None)

  def parameter_modifiers(self):
    return [
        DistributionParameter(),
        combinations_lib.OptionalParameter("use_var_policy"),
    ]


class ClusterCombination(combinations_lib.TestCombination):
  """Sets up multi worker tests."""

  def parameter_modifiers(self):
    return [ClusterParameters()]


class GPUCombination(combinations_lib.TestCombination):
  """Enables tests to request GPU hardware and skips non-GPU combinations.

  This class expects test_combinations to be generated with `NamedDistribution`
  wrapping instances of `tf.distribute.Strategy`.

  Optionally, the `required_gpus` argument is supported.  GPU hardware is
  required if its value is `True` or > 0.

  Attributes:
    GPU_TEST: The environment is considered to have GPU hardware available if
              the name of the program contains "test_gpu" or "test_xla_gpu".
  """
  GPU_TEST = False
  if sys.argv:
    GPU_TEST = re.search(r"(test_2?gpu|test_xla_2?gpu)$", sys.argv[0])

  def should_execute_combination(self, kwargs):
    distributions = [
        v for v in kwargs.values() if isinstance(v, NamedDistribution)
    ]
    required_gpus = kwargs.get("required_gpus", 0)
    required_physical_gpus = kwargs.get("required_physical_gpus", 0)

    if distributions and required_gpus:
      raise ValueError("Do not use `required_gpus` and arguments of type "
                       "NamedDistribution together.")

    number_of_required_gpus = max(
        [required_gpus] + [required_physical_gpus] +
        [d.required_physical_gpus or 0 for d in distributions] +
        [d.required_gpus or 0 for d in distributions])
    number_of_required_physical_gpus = max(
        [required_physical_gpus] +
        [d.required_physical_gpus or 0 for d in distributions])

    if required_physical_gpus and required_gpus:
      raise ValueError("Only one of `required_physical_gpus` (number of "
                       "physical GPUs required) and `required_gpus` (total "
                       "number of GPUs required) should be set.")
    if not number_of_required_gpus and GPUCombination.GPU_TEST:
      return (False, "Test that doesn't require GPUs.")
    elif (number_of_required_gpus > 0
          and context.num_gpus() < number_of_required_gpus):
      return (False, ("Only {} of {} required GPUs are available.".format(
          context.num_gpus(), number_of_required_gpus)))
    elif number_of_required_physical_gpus > len(
        config.list_physical_devices("GPU")):
      return (False,
              ("Only {} of {} required physical GPUs are available.".format(
                  len(config.list_physical_devices("GPU")),
                  number_of_required_physical_gpus)))
    else:
      return (True, None)

  def parameter_modifiers(self):
    return [combinations_lib.OptionalParameter("required_gpus"),
            combinations_lib.OptionalParameter("required_physical_gpus")]
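
# For example (a sketch): the combination below is skipped, with an explanatory
# message, unless at least two GPUs are available:
#
# @combinations.generate(combinations.combine(required_gpus=2))
# def testXXX(self):
#   ...  # Runs only when context.num_gpus() >= 2.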


class TPUCombination(combinations_lib.TestCombination):
  """Allows tests to request TPU hardware and skips non-TPU combinations.

  This class expects test_combinations to be generated with `NamedDistribution`
  wrapping instances of `tf.distribute.Strategy`.

  Optionally, the `required_tpus` parameter is supported.  TPU hardware is
  required if its value is `True` or > 0.

  Optionally, the `use_cloud_tpu` parameter is supported. If TPU hardware is
  required by `required_tpus`, it specifically must be a Cloud TPU (specified
  with `--tpu`) if `use_cloud_tpu` is `True`.

  Attributes:
    TPU_TEST: The environment is considered to have TPU hardware available if
              the name of the program contains "test_tpu".
  """

  TPU_TEST = False
  if sys.argv:
    TPU_TEST = "test_tpu" in sys.argv[0]

  def should_execute_combination(self, kwargs):
    distributions = [
        v for v in kwargs.values() if isinstance(v, NamedDistribution)
    ]
    # TODO(isaprykin): Migrate all tests away from using 'required_tpu' in favor
    # of 'required_tpus'.
    if "required_tpus" in kwargs and "required_tpu" in kwargs:
      raise ValueError("Do not use `required_tpu`.  Both `required_tpus` and "
                       "`required_tpu` were specified.")
    required_tpus = kwargs.get("required_tpus", None) or kwargs.get(
        "required_tpu", None)

    if distributions and required_tpus:
      raise ValueError("Do not use `required_tpus` and arguments of type "
                       "NamedDistribution together.")

    # TODO(isaprykin): Add support for a particular number of TPUs.  Right now
    # it's binary.
    number_of_required_tpus = max([required_tpus or 0] +
                                  [d.required_tpu or 0 for d in distributions])
    use_cloud_tpu = any([kwargs.get("use_cloud_tpu")] +
                        [d.use_cloud_tpu for d in distributions])
    tpu = hasattr(flags.FLAGS, "tpu") and flags.FLAGS.tpu or ""

    if not number_of_required_tpus and TPUCombination.TPU_TEST:
      return (False, "Test that doesn't require TPUs.")
    if number_of_required_tpus and not TPUCombination.TPU_TEST:
      return (False, "Test requires a TPU, but it's not available.")
    if use_cloud_tpu and not tpu:
      return (False, "Test requires a Cloud TPU, but none specified.")
    if not use_cloud_tpu and tpu:
      return (False, "Test requires local TPU, but Cloud TPU specified.")
    return (True, None)

  def parameter_modifiers(self):
    return [
        combinations_lib.OptionalParameter("required_tpus"),
        combinations_lib.OptionalParameter("required_tpu"),
        combinations_lib.OptionalParameter("use_cloud_tpu"),
    ]
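
# For example (a sketch): the combination below runs only in TPU builds that
# were also given a Cloud TPU via the `--tpu` flag:
#
# @combinations.generate(
#     combinations.combine(required_tpus=1, use_cloud_tpu=True))
# def testXXX(self):
#   ...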


class NamedDistribution(object):
  """Wraps a `tf.distribute.Strategy` and adds a name for test titles."""

  def __init__(self,
               name,
               distribution_fn,
               required_gpus=None,
               required_physical_gpus=0,
               required_tpu=False,
               use_cloud_tpu=False,
               has_chief=False,
               num_workers=1,
               num_ps=0,
               share_gpu=True,
               pool_runner_fn=None,
               no_xla=False):
    """Initialize NamedDistribution.

    Args:
      name: Name that will be a part of the name of the test case.
      distribution_fn: A callable that creates a `tf.distribute.Strategy`.
      required_gpus: The number of GPUs that the strategy requires. Only one of
        `required_gpus` and `required_physical_gpus` should be set.
      required_physical_gpus: Number of physical GPUs required. Only one of
        `required_gpus` and `required_physical_gpus` should be set.
      required_tpu: Whether the strategy requires TPU.
      use_cloud_tpu: Whether the strategy requires cloud TPU.
      has_chief: Whether the strategy requires a chief worker.
      num_workers: The number of workers that the strategy requires.
      num_ps: The number of parameter servers.
      share_gpu: Whether to share GPUs among workers.
      pool_runner_fn: An optional callable that returns a MultiProcessPoolRunner
        to run the test.
      no_xla: Whether to skip in XLA tests.
    """
    object.__init__(self)
    self._name = name
    self._distribution_fn = distribution_fn
    self.required_gpus = required_gpus
    self.required_physical_gpus = required_physical_gpus
    self.required_tpu = required_tpu
    self.use_cloud_tpu = use_cloud_tpu
    self.has_chief = has_chief
    self.num_workers = num_workers
    self.num_ps = num_ps
    self.share_gpu = share_gpu
    self._pool_runner_fn = pool_runner_fn
    self.no_xla = no_xla

  @property
  def runner(self):
    if self._pool_runner_fn is not None:
      return self._pool_runner_fn()
    return None

  @property
  def strategy(self):
    return self._distribution_fn()

  def __repr__(self):
    return self._name
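
# For example (a sketch in the spirit of the premade entries in
# strategy_combinations; the exact values here are illustrative):
#
# mirrored_two_gpus = NamedDistribution(
#     "Mirrored2GPUs",
#     lambda: tf.distribute.MirroredStrategy(["/gpu:0", "/gpu:1"]),
#     required_gpus=2)
#
# The name becomes part of the generated test case title, and the lambda
# defers strategy construction until the test actually runs.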


# This is to allow adding combinations that run a function both as a
# tf.function and eagerly.
#
# @combinations.generate(
#   combinations.combine(
#     tf_function = [combinations.tf_function, combinations.no_tf_function]
#   )
# )
# def testXXX(tf_function):
#   @tf_function
#   def foo():
#     tf.add(1., 1.)
#
#   foo()
tf_function = combinations_lib.NamedObject("TfFunction", def_function.function)
no_tf_function = combinations_lib.NamedObject("NoTfFunction", lambda f: f)


def concat(*combined):
  """Concats combinations."""
  result = []
  for one in combined:
    result += one
  return result
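
# For example (a sketch): concat() merges the combination lists produced by
# combine(), so the test below gets one generated case per entry:
#
# @combinations.generate(
#     concat(combine(mode=["graph"]),
#            combine(mode=["eager"], required_gpus=1)))
# def testXXX(self):
#   ...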


@tf_export("__internal__.distribute.combinations.generate", v1=[])
def generate(combinations, test_combinations=()):
  # pylint: disable=g-doc-args,g-doc-return-or-yield
  """Distributed adapter of `tf.__internal__.test.combinations.generate`.

  All tests with distribution strategies should use this instead of
  `tf.__internal__.test.combinations.generate`. This function adds support for
  strategy combinations, GPU/TPU requirements and multi-worker clusters.

  See `tf.__internal__.test.combinations.generate` for usage.
  """
  # pylint: enable=g-doc-args,g-doc-return-or-yield
  default_combinations = (
      framework_combinations.EagerGraphCombination(),
      framework_combinations.TFVersionCombination(),
      ClusterCombination(),
      DistributionCombination(),
      GPUCombination(),
      TPUCombination(),
  )
  # We apply our own decoration to handle multi worker tests before applying
  # framework.test_combinations.generate. The order is important since we need
  # framework.test_combinations.generate to apply all parameter modifiers first.
  combination_decorator = combinations_lib.generate(
      combinations, test_combinations=default_combinations + test_combinations)

  def decorator(test_method_or_class):
    if isinstance(test_method_or_class, type):
      # If it's a test class.
      class_object = test_method_or_class
      # Decorate each test method with _multi_worker_test.
      for name, test_method in six.iteritems(class_object.__dict__.copy()):
        if (name.startswith(unittest.TestLoader.testMethodPrefix) and
            isinstance(test_method, types.FunctionType)):
          setattr(class_object, name, _multi_worker_test(test_method))
      return combination_decorator(class_object)
    else:
      return combination_decorator(_multi_worker_test(test_method_or_class))

  return decorator
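
# For example (a sketch): generate() can also decorate a whole test class, in
# which case every test method is wrapped with _multi_worker_test first:
#
# @combinations.generate(
#     combinations.combine(num_workers=[2], has_chief=[True, False]))
# class MyMultiWorkerTest(test.TestCase):  # `test` as in TF's test library.
#
#   def testXXX(self):
#     ...  # Executed in each spawned worker process.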


combine = combinations_lib.combine
times = combinations_lib.times
NamedObject = combinations_lib.NamedObject


# Identifies whether we're in the main process or worker processes.
# `_multi_worker_test` decoration behaves differently in the main process and
# the worker processes. See the documentation of _multi_worker_test for details.
_running_in_worker = False


@tf_export("__internal__.distribute.combinations.in_main_process", v1=[])
def in_main_process():
  """Whether it's in the main test process.

  This is normally used to prepare the test environment, which should only
  happen in the main process.

  Returns:
    A boolean.
  """
  return not _running_in_worker
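
# For example (a sketch; `_start_dispatcher` is a hypothetical helper):
#
# def setUp(self):
#   super().setUp()
#   if combinations.in_main_process():
#     # One-time setup that worker processes then read back via env().
#     combinations.env().tf_data_service_dispatcher = _start_dispatcher()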


class TestEnvironment(object):
  """Holds the test environment information.

  Tests should modify the attributes of the instance returned by `env()` in the
  main process if needed, and it will be passed to the worker processes each
  time a test case is run.
  """

  def __init__(self):
    self.tf_data_service_dispatcher = None
    # Note that this includes GPUs that may not be visible to the current
    # worker.
    self.total_physical_gpus = None

  def __setattr__(self, name, value):
    if not in_main_process():
      raise ValueError(
          "combinations.env() should only be modified in the main process. "
          "Condition your code on combinations.in_main_process().")
    super().__setattr__(name, value)


_env = TestEnvironment()


@tf_export("__internal__.distribute.combinations.env", v1=[])
def env():
  """Returns the object that holds the test environment information.

  Tests should modify this in the main process if needed, and it will be passed
  to the worker processes each time a test case is run.

  Returns:
    A TestEnvironment object.
  """
  return _env


def _set_total_physical_gpus():
  if in_main_process():
    env().total_physical_gpus = len(
        context.context().list_physical_devices("GPU"))


# This is needed in case CUDA is lazily loaded.
app.call_after_init(_set_total_physical_gpus)


_TestResult = collections.namedtuple("_TestResult", ["status", "message"])


def _test_runner(test_id, test_env):
  """Executes the test with the given test_id.

  This is a simple wrapper around TestRunner to be used with
  multi_process_runner. Similar to test.main(), but it executes only one test
  specified by test_id and returns the result. If the test fails, the function
  prints failures and errors to stdout.

  Args:
    test_id: TestCase.id()
    test_env: a TestEnvironment object.

  Returns:
    A _TestResult namedtuple describing the outcome of the test.
  """
  global _running_in_worker, _env
  # No need to restore the value of _running_in_worker since it should always be
  # True in worker processes.
  _running_in_worker = True
  _env = test_env
  test = unittest.defaultTestLoader.loadTestsFromName(test_id)
  runner = unittest.TextTestRunner()
  result = runner.run(test)
  # Treat expected failures as failures, so that the main process can get
  # them and fail as expected. Also treat errors as failures to simplify the
  # handling.
  failures = result.failures + result.expectedFailures + result.errors
  if failures:
    ret = _TestResult(status="failure", message=failures[0][1])
  elif result.skipped:
    ret = _TestResult(status="skipped", message=result.skipped[0][1])
  else:
    # Treat unexpectedSuccesses as OK so that the test case in the main process
    # succeeds as well.
    ret = _TestResult(status="ok", message=None)
  # Print tracebacks to stdout and multi_process_runner will collect
  # them and stream back to the main process.
  if ret.message:
    print(ret.message)
  return ret


def _multi_worker_test(test_method):
  """Decorate test_method so that it runs in each worker.

  We use `multi_process_runner` to simulate multiple workers. Since we run this
  function in the main process and all worker processes, this decoration
  behaves differently in the main process and worker processes. In the main
  process, it spawns subprocesses and runs the test on each of them; in a
  worker process, it executes the test in the same way as a normal test, e.g.
  setUp()/tearDown() are called before/after the test.

  Args:
    test_method: a function which must be a test method.

  Returns:
    Decorated `test_method`. Note that the decorated function has additional
    arguments.
  """

  def decorator(self, has_chief, num_workers, num_ps, share_gpu, runner,
                **kwargs):
    if _num_total_workers(has_chief,
                          num_workers) == 1 or _running_in_worker or (
                              # Use in-process cluster for PS combinations
                              # when XLA is enabled.
                              test_util.is_xla_enabled() and num_ps > 0):
      # We're in a worker process or the test is for a single worker. In either
      # case we execute the test method directly instead of spawning
      # subprocesses.

      # For MultiWorkerMirroredStrategy(CollectiveAllReduceStrategy), install a
      # session that connects to the local server. This is necessary for multi
      # worker graph mode tests to work. Those tests cannot use their graphs or
      # sessions, including the one returned by self.cached_session(). Since
      # existing tests may already be doing so, we only install the session for
      # multi worker tests.
      with _multi_worker_session(kwargs):
        test_method(self, **kwargs)
      return

    # We're in the main process. We spawn subprocesses and run the *test* on
    # each of them. Note that we're not directly executing test_method passed to
    # _multi_worker_test, because we need setUp()/tearDown() to be called and
    # all the decorations on the test method. The conceptual call stack is:
    #   [main process]test.main()
    #     [main process]test_runner.run(test)
    #       [main process]wrapper by combinations.generate()
    #         [main process]_multi_worker_test.decorator()
    #           # A sub process goes through the same code path as the main
    #           # process.
    #           [sub process]_test_runner()
    #             [sub process]test_runner.run(test)
    #               [sub process]wrapper by combinations.generate()
    #                 [sub process]_multi_worker_test.decorator()
    #                   # _running_in_worker is True
    #                   [sub process]test_method()
    test_id = self.id()
    if runner:
      results = runner.run(_test_runner, args=(test_id, _env))
    else:
      cluster_spec = multi_worker_test_base.create_cluster_spec(
          has_chief=has_chief,
          num_workers=num_workers,
          num_ps=num_ps,
          has_eval=False)
      ephemeral_runner = multi_process_runner.MultiProcessRunner(
          _test_runner,
          cluster_spec,
          share_gpu=share_gpu,
          args=(test_id, _env),
          dependence_on_chief=has_chief)
      ephemeral_runner.start()
      results = ephemeral_runner.join().return_value

    skip_reason = None
    for result in results:
      if result.status == "failure":
        # We can't tell which worker the return value comes from, so we fail on
        # the first error.
        self.fail(result.message)
        break
      elif result.status == "skipped":
        # Record the skip reason, but do not actually skip the test in case some
        # processes fail instead.
        skip_reason = result.message
    if skip_reason is not None:
      self.skipTest(skip_reason)

  argspec = tf_inspect.getfullargspec(test_method)
  decorator_args = (argspec.args or []) + [
      "has_chief", "num_workers", "num_ps", "share_gpu", "runner"
  ]
  decorator_argspec = argspec._replace(args=decorator_args)
  return tf_decorator.make_decorator(
      test_method, decorator, decorator_argspec=decorator_argspec)
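
# For example (a sketch): because of this decoration, a test can request a bare
# multi-worker cluster without any strategy argument; the body then runs once
# in each of the two spawned worker processes, and failures or skips are
# streamed back to the main process by _test_runner:
#
# @combinations.generate(combinations.combine(num_workers=2))
# def testXXX(self):
#   ...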


def _num_total_workers(has_chief, num_workers):
  """Returns the number of workers including the chief."""
  if has_chief:
    return num_workers + 1
  return num_workers


def _multi_worker_session(kwargs):
  """Returns a context manager that enters a session configured for the strategy.

  Args:
    kwargs: a dict. Keyword arguments passed to the test.

  Returns:
    A context manager. If MultiWorkerMirroredStrategy is the one and only
    strategy in kwargs and it's in graph mode, it's the session that is
    configured for that strategy.  Otherwise, it's a no-op context manager.
  """
  strategy = None
  for _, v in kwargs.items():
    if isinstance(v, distribute_lib.StrategyBase):
      if strategy is not None:
        logging.warning(
            "The test uses multiple strategies. Skipping "
            "entering a session that is configured for the strategy.")
        return ops.NullContextmanager()
      strategy = v
  if context.executing_eagerly() or not isinstance(
      strategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy):
    return ops.NullContextmanager()
  sess_config = copy.deepcopy(context.context().config)
  sess_config = strategy.update_config_proto(sess_config)
  target = strategy.cluster_resolver.master()
  return session.Session(config=sess_config, target=target).as_default()
653