# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This module customizes `test_combinations` for `tf.distribute.Strategy`.

Additionally it provides `generate()`, `combine()` and `times()` with
`tf.distribute.Strategy` customizations as a default.
"""

import collections
import copy
import re
import sys
import types
import unittest

from absl import app
import six


from tensorflow.python.client import session
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import combinations as framework_combinations
from tensorflow.python.framework import config
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_combinations as combinations_lib
from tensorflow.python.framework import test_util
from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


# TODO(rchao): Rename `distribution` parameter to `strategy` or
# `distribute_strategy` in all tests.
class DistributionParameter(combinations_lib.ParameterModifier):
  """Transforms arguments of type `NamedDistribution`.

  Convert all arguments of type `NamedDistribution` to the value of their
  `strategy` property.
  """

  def modified_arguments(self, kwargs, requested_parameters):
    # Get the parameter that indicates if we need to set the `_use_policy` flag
    # on the strategy object. This is a temporary flag for testing the variable
    # policy rollout.
    use_var_policy = kwargs.get("use_var_policy", None)
    distribution_arguments = {}
    for k, v in kwargs.items():
      if isinstance(v, NamedDistribution):
        strategy = v.strategy
        if use_var_policy:
          strategy.extended._use_var_policy = use_var_policy
        distribution_arguments[k] = strategy
    return distribution_arguments


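# For example, a test parameterized with a `NamedDistribution` receives the
# underlying `tf.distribute.Strategy` instance rather than the wrapper. A
# minimal sketch (the `strategy_combinations` alias used here is illustrative):
#
#   @combinations.generate(
#       combinations.combine(
#           distribution=[strategy_combinations.one_device_strategy],
#           mode=["graph", "eager"]))
#   def testXXX(self, distribution):
#     # `distribution` is the `tf.distribute.Strategy` created by the
#     # wrapper's `distribution_fn`.
#     ...
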
class ClusterParameters(combinations_lib.ParameterModifier):
  """Adds cluster parameters if a `NamedDistribution` has them.

  It needs to be applied before `DistributionParameter`.
  """

  def modified_arguments(self, kwargs, requested_parameters):
    strategy = None
    for _, v in kwargs.items():
      if isinstance(v, NamedDistribution):
        if strategy is not None and _num_total_workers(v.has_chief,
                                                       v.num_workers) > 1:
          raise ValueError("Only one NamedDistribution is supported for "
                           "multi worker tests.")
        strategy = v

    if strategy:
      has_chief = strategy.has_chief
      num_workers = strategy.num_workers
      runner = strategy.runner
      share_gpu = strategy.share_gpu
      num_ps = strategy.num_ps
      if "has_chief" in kwargs and kwargs["has_chief"] != has_chief:
        raise ValueError(
            "both has_chief and strategy specified but are not compatible")
      if "num_workers" in kwargs and kwargs["num_workers"] != num_workers:
        raise ValueError(
            "both num_workers and strategy specified but are not compatible")
    else:
      has_chief = kwargs.get("has_chief", False)
      num_workers = kwargs.get("num_workers", 1)
      runner = kwargs.get("runner", None)
      share_gpu = kwargs.get("share_gpu", True)
      num_ps = kwargs.get("num_ps", 0)

    # Always set cluster parameters if they're requested, so that generate()
    # works when there's no strategy in the combinations.
    update = {}
    if "has_chief" in requested_parameters:
      update["has_chief"] = has_chief
    if "num_workers" in requested_parameters:
      update["num_workers"] = num_workers
    if "runner" in requested_parameters:
      update["runner"] = runner
    if "share_gpu" in requested_parameters:
      update["share_gpu"] = share_gpu
    if "num_ps" in requested_parameters:
      update["num_ps"] = num_ps
    return update


class DistributionCombination(combinations_lib.TestCombination):
  """Sets up distribution strategy for tests."""

  def should_execute_combination(self, kwargs):
    distributions = [
        v for v in kwargs.values() if isinstance(v, NamedDistribution)
    ]
    if test_util.is_xla_enabled() and any(d.no_xla for d in distributions):
      return (
          False,
          "n/a: skipping strategy combination with no_xla=True in XLA tests")
    return (True, None)

  def parameter_modifiers(self):
    return [
        DistributionParameter(),
        combinations_lib.OptionalParameter("use_var_policy"),
    ]


class ClusterCombination(combinations_lib.TestCombination):
  """Sets up multi worker tests."""

  def parameter_modifiers(self):
    return [ClusterParameters()]


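# For example, a multi worker combination is typically expressed through a
# `NamedDistribution` whose `has_chief`/`num_workers` describe the cluster. A
# minimal sketch (the `strategy_combinations` alias used here is illustrative):
#
#   @combinations.generate(
#       combinations.combine(
#           distribution=[strategy_combinations.multi_worker_mirrored_2x1_cpu],
#           mode=["eager"]))
#   def testXXX(self, distribution):
#     # Runs in each spawned worker process; see `_multi_worker_test`.
#     ...
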
class GPUCombination(combinations_lib.TestCombination):
  """Enable tests to request GPU hardware and skip non-GPU combinations.

  This class expects test_combinations to be generated with `NamedDistribution`
  wrapping instances of `tf.distribute.Strategy`.

  Optionally, the `required_gpus` argument is supported. GPU hardware is
  required if its value is `True` or > 0.

  Attributes:
    GPU_TEST: The environment is considered to have GPU hardware available if
      the name of the program ends with "test_gpu", "test_2gpu", "test_xla_gpu"
      or "test_xla_2gpu".
  """
  GPU_TEST = False
  if sys.argv:
    GPU_TEST = re.search(r"(test_2?gpu|test_xla_2?gpu)$", sys.argv[0])

  def should_execute_combination(self, kwargs):
    distributions = [
        v for v in kwargs.values() if isinstance(v, NamedDistribution)
    ]
    required_gpus = kwargs.get("required_gpus", 0)
    required_physical_gpus = kwargs.get("required_physical_gpus", 0)

    if distributions and required_gpus:
      raise ValueError("Do not use `required_gpus` and arguments of type "
                       "NamedDistribution together.")

    number_of_required_gpus = max(
        [required_gpus] + [required_physical_gpus] +
        [d.required_physical_gpus or 0 for d in distributions] +
        [d.required_gpus or 0 for d in distributions])
    number_of_required_physical_gpus = max(
        [required_physical_gpus] +
        [d.required_physical_gpus or 0 for d in distributions])

    if required_physical_gpus and required_gpus:
      raise ValueError("Only one of `required_physical_gpus` (number of "
                       "physical GPUs required) and `required_gpus` (total "
                       "number of GPUs required) should be set.")
    if not number_of_required_gpus and GPUCombination.GPU_TEST:
      return (False, "Test that doesn't require GPUs.")
    elif (number_of_required_gpus > 0
          and context.num_gpus() < number_of_required_gpus):
      return (False, ("Only {} of {} required GPUs are available.".format(
          context.num_gpus(), number_of_required_gpus)))
    elif number_of_required_physical_gpus > len(
        config.list_physical_devices("GPU")):
      return (False,
              ("Only {} of {} required physical GPUs are available.".format(
                  len(config.list_physical_devices("GPU")),
                  number_of_required_physical_gpus)))
    else:
      return (True, None)

  def parameter_modifiers(self):
    return [combinations_lib.OptionalParameter("required_gpus"),
            combinations_lib.OptionalParameter("required_physical_gpus")]


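# For example, a combination that needs GPU hardware but no strategy can
# request it directly. A minimal sketch (`required_gpus` is an optional
# parameter, so the test method does not have to accept it):
#
#   @combinations.generate(
#       combinations.combine(mode=["graph", "eager"], required_gpus=2))
#   def testXXX(self):
#     ...
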
class TPUCombination(combinations_lib.TestCombination):
  """Allows tests to request TPU hardware and skip non-TPU combinations.

  This class expects test_combinations to be generated with `NamedDistribution`
  wrapping instances of `tf.distribute.Strategy`.

  Optionally, the `required_tpus` parameter is supported. TPU hardware is
  required if its argument is `True` or > 0.

  Optionally, the `use_cloud_tpu` parameter is supported. If TPU hardware is
  required by `required_tpus`, it specifically must be a Cloud TPU (specified
  with `--tpu`) if `use_cloud_tpu` is `True`.

  Attributes:
    TPU_TEST: The environment is considered to have TPU hardware available if
      the name of the program contains "test_tpu".
  """

  TPU_TEST = False
  if sys.argv:
    TPU_TEST = "test_tpu" in sys.argv[0]

  def should_execute_combination(self, kwargs):
    distributions = [
        v for v in kwargs.values() if isinstance(v, NamedDistribution)
    ]
    # TODO(isaprykin): Migrate all tests away from using 'required_tpu' in
    # favor of 'required_tpus'.
    if "required_tpus" in kwargs and "required_tpu" in kwargs:
      raise ValueError("Do not use `required_tpu`. Both `required_tpus` and "
                       "`required_tpu` were specified.")
    required_tpus = kwargs.get("required_tpus", None) or kwargs.get(
        "required_tpu", None)

    if distributions and required_tpus:
      raise ValueError("Do not use `required_tpus` and arguments of type "
                       "NamedDistribution together.")

    # TODO(isaprykin): Add support for a particular number of TPUs. Right now
    # it's binary.
    number_of_required_tpus = max([required_tpus or 0] +
                                  [d.required_tpu or 0 for d in distributions])
    use_cloud_tpu = any([kwargs.get("use_cloud_tpu")] +
                        [d.use_cloud_tpu for d in distributions])
    tpu = hasattr(flags.FLAGS, "tpu") and flags.FLAGS.tpu or ""

    if not number_of_required_tpus and TPUCombination.TPU_TEST:
      return (False, "Test that doesn't require TPUs.")
    if number_of_required_tpus and not TPUCombination.TPU_TEST:
      return (False, "Test requires a TPU, but it's not available.")
    if use_cloud_tpu and not tpu:
      return (False, "Test requires a Cloud TPU, but none specified.")
    if not use_cloud_tpu and tpu:
      return (False, "Test requires local TPU, but Cloud TPU specified.")
    return (True, None)

  def parameter_modifiers(self):
    return [
        combinations_lib.OptionalParameter("required_tpus"),
        combinations_lib.OptionalParameter("required_tpu"),
        combinations_lib.OptionalParameter("use_cloud_tpu"),
    ]


class NamedDistribution(object):
  """Wraps a `tf.distribute.Strategy` and adds a name for test titles."""

  def __init__(self,
               name,
               distribution_fn,
               required_gpus=None,
               required_physical_gpus=0,
               required_tpu=False,
               use_cloud_tpu=False,
               has_chief=False,
               num_workers=1,
               num_ps=0,
               share_gpu=True,
               pool_runner_fn=None,
               no_xla=False):
    """Initialize NamedDistribution.

    Args:
      name: Name that will be a part of the name of the test case.
      distribution_fn: A callable that creates a `tf.distribute.Strategy`.
      required_gpus: The number of GPUs that the strategy requires. Only one of
        `required_gpus` and `required_physical_gpus` should be set.
      required_physical_gpus: Number of physical GPUs required. Only one of
        `required_gpus` and `required_physical_gpus` should be set.
      required_tpu: Whether the strategy requires TPU.
      use_cloud_tpu: Whether the strategy requires cloud TPU.
      has_chief: Whether the strategy requires a chief worker.
      num_workers: The number of workers that the strategy requires.
      num_ps: The number of parameter servers.
      share_gpu: Whether to share GPUs among workers.
      pool_runner_fn: An optional callable that returns a
        MultiProcessPoolRunner to run the test.
      no_xla: Whether to skip in XLA tests.
    """
    object.__init__(self)
    self._name = name
    self._distribution_fn = distribution_fn
    self.required_gpus = required_gpus
    self.required_physical_gpus = required_physical_gpus
    self.required_tpu = required_tpu
    self.use_cloud_tpu = use_cloud_tpu
    self.has_chief = has_chief
    self.num_workers = num_workers
    self.num_ps = num_ps
    self.share_gpu = share_gpu
    self._pool_runner_fn = pool_runner_fn
    self.no_xla = no_xla

  @property
  def runner(self):
    if self._pool_runner_fn is not None:
      return self._pool_runner_fn()
    return None

  @property
  def strategy(self):
    return self._distribution_fn()

  def __repr__(self):
    return self._name


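# For example, a strategy can be registered for use in combinations by wrapping
# its constructor in a `NamedDistribution`. A minimal sketch (the name and the
# two-GPU mirrored strategy shown are illustrative):
#
#   mirrored_two_gpus = NamedDistribution(
#       "MirroredTwoGPUs",
#       lambda: tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]),
#       required_gpus=2)
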
# This is to allow adding combinations that run a function both as a
# tf.function and eagerly.
#
# @combinations.generate(
#     combinations.combine(
#         tf_function = [combinations.tf_function, combinations.no_tf_function]
#     )
# )
# def testXXX(tf_function):
#   @tf_function
#   def foo():
#     tf.add(1., 1.)
#
#   foo()
tf_function = combinations_lib.NamedObject("TfFunction", def_function.function)
no_tf_function = combinations_lib.NamedObject("NoTfFunction", lambda f: f)


def concat(*combined):
  """Concats combinations."""
  result = []
  for one in combined:
    result += one
  return result


@tf_export("__internal__.distribute.combinations.generate", v1=[])
def generate(combinations, test_combinations=()):
  # pylint: disable=g-doc-args,g-doc-return-or-yield
  """Distributed adapter of `tf.__internal__.test.combinations.generate`.

  All tests with distributed strategy should use this one instead of
  `tf.__internal__.test.combinations.generate`. This function adds support for
  strategy combinations, GPU/TPU requirements and multi worker tests.

  See `tf.__internal__.test.combinations.generate` for usage.
  """
  # pylint: enable=g-doc-args,g-doc-return-or-yield
  default_combinations = (
      framework_combinations.EagerGraphCombination(),
      framework_combinations.TFVersionCombination(),
      ClusterCombination(),
      DistributionCombination(),
      GPUCombination(),
      TPUCombination(),
  )
  # We apply our own decoration to handle multi worker tests before applying
  # framework.test_combinations.generate. The order is important since we need
  # framework.test_combinations.generate to apply all parameter modifiers
  # first.
  combination_decorator = combinations_lib.generate(
      combinations, test_combinations=default_combinations + test_combinations)

  def decorator(test_method_or_class):
    if isinstance(test_method_or_class, type):
      # If it's a test class.
      class_object = test_method_or_class
      # Decorate each test method with _multi_worker_test.
      for name, test_method in six.iteritems(class_object.__dict__.copy()):
        if (name.startswith(unittest.TestLoader.testMethodPrefix) and
            isinstance(test_method, types.FunctionType)):
          setattr(class_object, name, _multi_worker_test(test_method))
      return combination_decorator(class_object)
    else:
      return combination_decorator(_multi_worker_test(test_method_or_class))

  return decorator


combine = combinations_lib.combine
times = combinations_lib.times
NamedObject = combinations_lib.NamedObject


# Identifies whether we're in the main process or worker processes.
# `_multi_worker_test` decoration behaves differently in the main process and
# the worker processes. See the documentation of _multi_worker_test for detail.
_running_in_worker = False


@tf_export("__internal__.distribute.combinations.in_main_process", v1=[])
def in_main_process():
  """Whether it's in the main test process.

  This is normally used to prepare the test environment which should only
  happen in the main process.

  Returns:
    A boolean.
  """
  return not _running_in_worker


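# For example, state that must be prepared exactly once can be set up in the
# main process and shared with workers through `env()`. A minimal sketch (the
# dispatcher helper is illustrative):
#
#   def setUp(self):
#     super().setUp()
#     if combinations.in_main_process():
#       # env() is passed to every worker process when a test runs.
#       combinations.env().tf_data_service_dispatcher = (
#           self._start_dispatcher())
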
class TestEnvironment(object):
  """Holds the test environment information.

  Tests should modify the attributes of the instance returned by `env()` in the
  main process if needed, and it will be passed to the worker processes each
  time a test case is run.
  """

  def __init__(self):
    self.tf_data_service_dispatcher = None
    # Note that this includes GPUs that may not be visible to the current
    # worker.
    self.total_phsyical_gpus = None

  def __setattr__(self, name, value):
    if not in_main_process():
      raise ValueError(
          "combinations.env() should only be modified in the main process. "
          "Condition your code on combinations.in_main_process().")
    super().__setattr__(name, value)


_env = TestEnvironment()


@tf_export("__internal__.distribute.combinations.env", v1=[])
def env():
  """Returns the object that holds the test environment information.

  Tests should modify this in the main process if needed, and it will be passed
  to the worker processes each time a test case is run.

  Returns:
    a TestEnvironment object.
  """
  return _env


def _set_total_phsyical_gpus():
  if in_main_process():
    env().total_phsyical_gpus = len(
        context.context().list_physical_devices("GPU"))


# This is needed in case CUDA is lazily loaded.
app.call_after_init(_set_total_phsyical_gpus)


_TestResult = collections.namedtuple("_TestResult", ["status", "message"])


def _test_runner(test_id, test_env):
  """Executes the test with the given test_id.

  This is a simple wrapper around TestRunner to be used with
  multi_process_runner. Similar to test.main(), but it executes only one test
  specified by test_id and reports whether the test succeeds. If the test
  fails, the function prints failures and errors to stdout.

  Args:
    test_id: TestCase.id()
    test_env: a TestEnvironment object.

  Returns:
    A `_TestResult` with the test status ("ok", "skipped" or "failure") and an
    optional message.
  """
  global _running_in_worker, _env
  # No need to restore the value of _running_in_worker since it should always
  # be True in worker processes.
  _running_in_worker = True
  _env = test_env
  test = unittest.defaultTestLoader.loadTestsFromName(test_id)
  runner = unittest.TextTestRunner()
  result = runner.run(test)
  # Treat expected failures as failures, so that the main process can get
  # them and fail as expected. Also treat errors as failures to simplify the
  # handling.
  failures = result.failures + result.expectedFailures + result.errors
  if failures:
    ret = _TestResult(status="failure", message=failures[0][1])
  elif result.skipped:
    ret = _TestResult(status="skipped", message=result.skipped[0][1])
  else:
    # Treat unexpectedSuccesses as OK so that the test case in the main
    # process succeeds as well.
    ret = _TestResult(status="ok", message=None)
  # Print tracebacks to stdout and multi_process_runner will collect
  # them and stream back to the main process.
  if ret.message:
    print(ret.message)
  return ret


def _multi_worker_test(test_method):
  """Decorate test_method so that it runs in each worker.

  We use `multi_process_runner` to simulate multiple workers. Since we run this
  function in the main process and all worker processes, this decoration
  behaves differently in the main process and worker processes. In the main
  process, it spawns subprocesses and runs the test on each of them; in a
  worker process, it executes the test in the same way as a normal test, e.g.
  setUp()/tearDown() are called before/after the test.

  Args:
    test_method: a function which must be a test method.

  Returns:
    Decorated `test_method`. Note that the decorated function has additional
    arguments.
  """

  def decorator(self, has_chief, num_workers, num_ps, share_gpu, runner,
                **kwargs):
    if _num_total_workers(has_chief,
                          num_workers) == 1 or _running_in_worker or (
                              # Use in-process cluster for PS combinations
                              # when XLA is enabled.
                              test_util.is_xla_enabled() and num_ps > 0):
      # We're in a worker process or the test is for a single worker. In either
      # case we execute the test method directly instead of spawning
      # subprocesses.

      # For MultiWorkerMirroredStrategy(CollectiveAllReduceStrategy), install a
      # session that connects to the local server. This is necessary for multi
      # worker graph mode tests to work. Those tests cannot use their graphs or
      # sessions, including the one returned by self.cached_session(). Since
      # existing tests may already be doing so, we only install the session for
      # multi worker tests.
      with _multi_worker_session(kwargs):
        test_method(self, **kwargs)
      return

    # We're in the main process. We spawn subprocesses and run the *test* on
    # each of them. Note that we're not directly executing test_method passed
    # to _multi_worker_test, because we need setUp()/tearDown() to be called
    # and all the decorations on the test method. The conceptual call stack is:
    #   [main process]test.main()
    #     [main process]test_runner.run(test)
    #       [main process]wrapper by combinations.generate()
    #         [main process]_multi_worker_test.decorator()
    #           # A sub process goes through the same code path as the main
    #           # process.
    #           [sub process]_test_runner()
    #             [sub process]test_runner.run(test)
    #               [sub process]wrapper by combinations.generate()
    #                 [sub process]_multi_worker_test.decorator()
    #                   # _running_in_worker is True
    #                   [sub process]test_method()
    test_id = self.id()
    if runner:
      results = runner.run(_test_runner, args=(test_id, _env))
    else:
      cluster_spec = multi_worker_test_base.create_cluster_spec(
          has_chief=has_chief,
          num_workers=num_workers,
          num_ps=num_ps,
          has_eval=False)
      ephemeral_runner = multi_process_runner.MultiProcessRunner(
          _test_runner,
          cluster_spec,
          share_gpu=share_gpu,
          args=(test_id, _env),
          dependence_on_chief=has_chief)
      ephemeral_runner.start()
      results = ephemeral_runner.join().return_value

    skip_reason = None
    for result in results:
      if result.status == "failure":
        # We can't tell which worker the return value comes from, so we fail
        # on the first error.
        self.fail(result.message)
        break
      elif result.status == "skipped":
        # Record the skip reason, but do not actually skip the test in case
        # some processes fail instead.
        skip_reason = result.message
    if skip_reason is not None:
      self.skipTest(skip_reason)

  argspec = tf_inspect.getfullargspec(test_method)
  decorator_args = (argspec.args or []) + [
      "has_chief", "num_workers", "num_ps", "share_gpu", "runner"
  ]
  decorator_argspec = argspec._replace(args=decorator_args)
  return tf_decorator.make_decorator(
      test_method, decorator, decorator_argspec=decorator_argspec)


def _num_total_workers(has_chief, num_workers):
  """Returns the number of workers including the chief."""
  if has_chief:
    return num_workers + 1
  return num_workers


def _multi_worker_session(kwargs):
  """Returns a context manager that enters a session configured for MultiWorkerMirroredStrategy.

  Args:
    kwargs: a dict. Keyword arguments passed to the test.

  Returns:
    A context manager. If MultiWorkerMirroredStrategy is the only strategy in
    kwargs and it's in graph mode, it's the session that is configured for that
    strategy. Otherwise, it's a no-op context manager.
  """
  strategy = None
  for _, v in kwargs.items():
    if isinstance(v, distribute_lib.StrategyBase):
      if strategy is not None:
        logging.warning(
            "The test uses multiple strategies. Skipping "
            "entering a session that is configured for the strategy.")
        return ops.NullContextmanager()
      strategy = v
  if context.executing_eagerly() or not isinstance(
      strategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy):
    return ops.NullContextmanager()
  sess_config = copy.deepcopy(context.context().config)
  sess_config = strategy.update_config_proto(sess_config)
  target = strategy.cluster_resolver.master()
  return session.Session(config=sess_config, target=target).as_default()