# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Strategy combinations for combinations.combine()."""

import sys
import unittest
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import cluster_resolver
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import mirrored_strategy as mirrored_lib
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import one_device_strategy as one_device_lib
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import tpu_strategy as tpu_lib
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.eager import remote
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util as framework_test_util
from tensorflow.python.platform import flags
from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.training import server_lib
from tensorflow.python.util.tf_export import tf_export

_TF_INTERNAL_API_PREFIX = "__internal__.distribute.combinations."

_did_connect_to_cluster = False
_topology = None
CollectiveAllReduceExtended = (
    collective_all_reduce_strategy.CollectiveAllReduceExtended)


def _version_chooser(tf1_cls, tf2_cls):
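  """Returns a creator that forwards to `tf2_cls` or `tf1_cls`.

  The returned callable instantiates `tf2_cls` when TF2 behavior is enabled
  and `tf1_cls` otherwise, passing through any arguments it receives.
  """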

  def creator(*args, **kwargs):
    if tf2.enabled():
      return tf2_cls(*args, **kwargs)
    return tf1_cls(*args, **kwargs)

  return creator


MirroredStrategy = _version_chooser(mirrored_lib.MirroredStrategyV1,
                                    mirrored_lib.MirroredStrategy)
CentralStorageStrategy = _version_chooser(
    central_storage_strategy.CentralStorageStrategyV1,
    central_storage_strategy.CentralStorageStrategy)
OneDeviceStrategy = _version_chooser(one_device_lib.OneDeviceStrategyV1,
                                     one_device_lib.OneDeviceStrategy)
# Only V2 CollectiveAllReduceStrategy combinations are supported.
CollectiveAllReduceStrategy = (
    collective_all_reduce_strategy.CollectiveAllReduceStrategy)


# pylint: disable=missing-docstring
def _get_tpu_strategy_creator(steps_per_run,
                              use_single_core=False,
                              enable_packed_variable=False,
                              enable_spmd_xla_paritioning=False,
                              **kwargs):
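  """Returns a creator for a TPUStrategy with the given options.

  The creator resolves and initializes the TPU system once per process, then
  builds a `TPUStrategyV2` under TF2 (where `steps_per_run` is ignored) or a
  `TPUStrategyV1` with `steps_per_run` under TF1. Remaining `kwargs` are
  forwarded to the strategy constructor.
  """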

  def _create_tpu_strategy():
    FLAGS = flags.FLAGS  # pylint: disable=invalid-name
    global _did_connect_to_cluster
    global _topology

    try:
      # Attempt to locally discover the TPU. This will fail for Cloud TPU, in
      # which case we fall back to the values passed as flags.
      resolver = tpu_cluster_resolver.TPUClusterResolver()
      did_automatically_resolve = True
    except ValueError:
      did_automatically_resolve = False

      # These flags will be defined by tpu_test_wrapper.py.
      resolver = tpu_cluster_resolver.TPUClusterResolver(
          tpu=hasattr(FLAGS, "tpu") and FLAGS.tpu or "",
          zone=hasattr(FLAGS, "zone") and FLAGS.zone or None,
          project=hasattr(FLAGS, "project") and FLAGS.project or None,
      )

    # Only connect once per process, rather than per test method.
    if not _did_connect_to_cluster:
      if getattr(FLAGS, "tpu", "") or did_automatically_resolve:
        remote.connect_to_cluster(resolver)
        _did_connect_to_cluster = True
      _topology = tpu_strategy_util.initialize_tpu_system(resolver)

    device_assignment = None
    if use_single_core:
      device_assignment = device_assignment_lib.DeviceAssignment(
          _topology,
          core_assignment=device_assignment_lib.SINGLE_CORE_ASSIGNMENT)

    # Steps per run is only supported in TF 1.x.
    if tf2.enabled():
      strategy = tpu_lib.TPUStrategyV2(
          resolver,
          device_assignment,
          experimental_spmd_xla_partitioning=enable_spmd_xla_paritioning,
          **kwargs)
    else:
      strategy = tpu_lib.TPUStrategyV1(resolver, steps_per_run,
                                       device_assignment, **kwargs)
    if enable_packed_variable and enable_spmd_xla_paritioning:
      raise ValueError("Packed Variable is not compatible with SPMD mode")
    strategy._enable_packed_variable_in_eager_mode = enable_packed_variable  # pylint: disable=protected-access
    return strategy

  return _create_tpu_strategy


def _mirrored_strategy_with_collective_key_base(devices):
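  # Bump the collective key base so that collectives from strategies created
  # by different test cases in the same process do not conflict.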
  mirrored_lib.MirroredStrategyV1._collective_key_base += 100000
  mirrored_lib.MirroredStrategy._collective_key_base += 100000
  return MirroredStrategy(devices)


def _mirrored_strategy_with_no_merge_call(devices):
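  """Creates a MirroredStrategy that never uses merge calls."""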
  mirrored_lib.MirroredStrategyV1._collective_key_base += 100000
  mirrored_lib.MirroredStrategy._collective_key_base += 100000
  out = MirroredStrategy(devices)
  # Stub out merge call usage.
  out.extended._use_merge_call = lambda: False  # pylint: disable=protected-access
  return out


def _get_multi_worker_mirrored_creator(required_gpus, use_merge_call=True):
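  """Returns a creator for MultiWorkerMirroredStrategy driven by TF_CONFIG.

  The creator builds a `SimpleClusterResolver` from the `TF_CONFIG`
  environment, disables the health check and coordination service to avoid
  flakiness, creates the strategy in eager mode, and waits on the
  multi-process barrier so that no worker launches collectives before the
  chief creates the strategy. Tests are skipped if the gRPC server cannot be
  started under bazel.
  """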

  def _create_multi_worker_mirrored():
    tf_config = cluster_resolver.TFConfigClusterResolver()
    master = tf_config.master()
    if tf_config.rpc_layer:
      # Strip off the rpc_layer prefix (e.g. "grpc://") from the master
      # address.
      master = master[len("%s://" % tf_config.rpc_layer):]
    resolver = cluster_resolver.SimpleClusterResolver(
        cluster_spec=tf_config.cluster_spec(),
        task_type=tf_config.task_type,
        task_id=tf_config.task_id,
        master=master,
        environment=tf_config.environment,
        num_accelerators={"GPU": required_gpus},
        rpc_layer=tf_config.rpc_layer or "grpc",
    )
    # Disable health check and coordination service. We don't have a reliable
    # way to shut down the strategy (and thus the strategy health check or
    # coordination service heartbeat) at the end of a test. Turning on the
    # strategy health check or coordination service heartbeat causes some
    # flakiness since we re-create part of the server when creating a strategy,
    # and our tests are capable of handling failures.
    CollectiveAllReduceExtended._enable_check_health = False  # pylint: disable=protected-access
    context.context().configure_coordination_service(service_type="")
    # Always create the strategy in eager mode so that it starts the server and
    # configures the eager context. The eager context can no longer be
    # configured after initialization.
    with context.eager_mode():
      strategy = CollectiveAllReduceStrategy(cluster_resolver=resolver)

    if not use_merge_call:
      strategy.extended._use_merge_call = lambda: False  # pylint: disable=protected-access
    # TODO(b/152320929): Wait for the cluster before proceeding, otherwise
    # collectives may hang if any worker launches collectives before the chief
    # creates the strategy.
    try:
      multi_process_runner.get_barrier().wait()
    except ValueError:
      # If the creator is called in the main process,
      # multi_process_runner.get_barrier() raises ValueError, which is safe to
      # ignore.
      pass
    return strategy

  def skip_if_cannot_start_grpc_server():
    try:
      return _create_multi_worker_mirrored()
    except errors.UnknownError as e:
      if "Could not start gRPC server" in e.message and (
          len(sys.argv) >= 1 and "bazel" in sys.argv[0]):
        raise unittest.SkipTest("Cannot start std servers.")
      else:
        raise

  return skip_if_cannot_start_grpc_server


# Due to b/195615322, FixedShardsPartitioner will wrongly partition the RNG
# state, so we use MinSizePartitioner as the default. The maximum RNG state
# size is int64[3], which is 8 * 3 = 24 bytes, so we set min_shard_bytes to
# 8 * 3 + 1 so that the RNG state is never partitioned.
DEFAULT_PARTITIONER = sharded_variable.MinSizePartitioner(
    min_shard_bytes=8 * 3 + 1, max_shards=2)


def _get_ps_strategy_creator(num_workers,
                             num_ps,
                             required_gpus=0,
                             variable_partitioner=DEFAULT_PARTITIONER):
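  """Returns a creator for ParameterServerStrategyV2 with the given cluster.

  Under XLA an in-process cluster is created; otherwise the cluster comes
  from `TF_CONFIG` set up by MultiProcessRunner, in which case worker and ps
  tasks start a standard server and block in `server.join()` while the chief
  builds the strategy.
  """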

  def _create_ps_strategy(resolver, variable_partitioner):
    return parameter_server_strategy_v2.ParameterServerStrategyV2(
        resolver, variable_partitioner=variable_partitioner)

  def _create_parameter_server():
    if framework_test_util.is_xla_enabled():
      # To address test failures when running XLA tests with
      # MultiProcessRunner, continue to use an in-process cluster for XLA
      # tests.
      cluster_def = multi_worker_test_base.create_in_process_cluster(
          num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
      resolver = cluster_resolver.SimpleClusterResolver(
          server_lib.ClusterSpec(cluster_def),
          num_accelerators={"GPU": required_gpus},
          rpc_layer="grpc")
      return _create_ps_strategy(resolver, variable_partitioner)
    else:
      tf_config = cluster_resolver.TFConfigClusterResolver()
      cluster_def = tf_config.cluster_spec().as_dict()
      if not cluster_def:
        # When a MultiProcessRunner cluster is used, the cluster is not
        # created yet when the decorator is evaluated: this method is first
        # invoked via the decorator before combinations.py sets up the
        # MultiProcessRunner with worker and ps tasks. After setup is done,
        # each subprocess invokes this method again to get the strategy
        # object, so we return None when the main thread invokes this method
        # before the cluster is set up.
        # Returning None is fine here, since the main thread will proceed to
        # create the MultiProcessRunner and run the decorated tests inside
        # subprocesses.
        return None
      # MultiProcessRunner is already set up and this method is invoked from a
      # subprocess running the actual test.
      resolver = cluster_resolver.SimpleClusterResolver(
          server_lib.ClusterSpec(cluster_def),
          num_accelerators={"GPU": required_gpus},
          task_type=tf_config.task_type,
          task_id=tf_config.task_id,
          environment=tf_config.environment,
          rpc_layer=tf_config.rpc_layer or "grpc")
      if tf_config.task_type in ("worker", "ps"):
        worker_config = config_pb2.ConfigProto()
        worker_config.inter_op_parallelism_threads = 4  # max num_workers + 1

        try:
          server = server_lib.Server(
              cluster_def,
              job_name=tf_config.task_type,
              task_index=tf_config.task_id,
              protocol="grpc",
              config=worker_config)
        except errors.UnknownError as e:
          if "Could not start gRPC server" in e.message:
            raise unittest.SkipTest("Cannot start std servers.")
          else:
            raise

        # Block the process that started a server from exiting.
        server.join()

      return _create_ps_strategy(resolver, variable_partitioner)

  return _create_parameter_server


def _deferred_pool_runner(has_chief,
                          num_workers,
                          initializer=None,
                          share_gpu=True):
  """Returns a callable that returns the pool runner.

  It creates the pool runner only upon first invocation. This avoids creating
  it when this file is imported.

  Args:
    has_chief: whether there should be a chief.
    num_workers: the number of workers excluding the chief.
    initializer: initializer of each process.
    share_gpu: whether to share GPU between the workers.

  Returns:
    A callable that returns the runner.
  """

  container = []

  def get_or_create():
    if not container:
      cluster_spec = multi_worker_test_base.create_cluster_spec(
          has_chief=has_chief,
          num_workers=num_workers,
          num_ps=0,
          has_eval=False)
      runner = multi_process_runner.MultiProcessPoolRunner(
          cluster_spec, initializer=initializer, share_gpu=share_gpu)
      container.append(runner)
    return container[0]

  return get_or_create


# We need to create the strategy in the initializer to start the server before
# any test runs.
_two_worker_pool = _deferred_pool_runner(
    has_chief=True,
    num_workers=1,
    initializer=_get_multi_worker_mirrored_creator(required_gpus=0))

# Two-worker pool where each worker gets its own GPU. Useful for testing MWMS
# on a single host.
_two_worker_pool_noshare = _deferred_pool_runner(
    has_chief=True,
    num_workers=1,
    initializer=_get_multi_worker_mirrored_creator(required_gpus=0),
    share_gpu=False)
_four_worker_pool = _deferred_pool_runner(
    has_chief=True,
    num_workers=3,
    initializer=_get_multi_worker_mirrored_creator(required_gpus=0))

# pylint: disable=g-long-lambda
default_strategy = combinations.NamedDistribution(
    "Default",
    distribution_strategy_context._get_default_strategy,  # pylint: disable=protected-access
    required_gpus=None)
one_device_strategy = combinations.NamedDistribution(
    "OneDeviceCPU", lambda: OneDeviceStrategy("/cpu:0"), required_gpus=None)
one_device_strategy_gpu = combinations.NamedDistribution(
    "OneDeviceGPU", lambda: OneDeviceStrategy("/gpu:0"), required_gpus=1)
one_device_strategy_on_worker_1 = combinations.NamedDistribution(
    "OneDeviceOnWorker1CPU",
    lambda: OneDeviceStrategy("/job:worker/replica:0/task:1/cpu:0"),
    required_gpus=None)
one_device_strategy_gpu_on_worker_1 = combinations.NamedDistribution(
    "OneDeviceOnWorker1GPU",
    lambda: OneDeviceStrategy("/job:worker/replica:0/task:1/gpu:0"),
    required_gpus=1)
tpu_strategy = combinations.NamedDistribution(
    "TPU", _get_tpu_strategy_creator(steps_per_run=2), required_tpu=True)
tpu_strategy_packed_var = combinations.NamedDistribution(
    "TPUPackedVar",
    _get_tpu_strategy_creator(steps_per_run=2, enable_packed_variable=True),
    required_tpu=True)
tpu_strategy_spmd = combinations.NamedDistribution(
    "TPUUseSPMD",
    _get_tpu_strategy_creator(
        steps_per_run=2, enable_spmd_xla_paritioning=True),
    required_tpu=True)
tpu_strategy_one_step = combinations.NamedDistribution(
    "TPUOneStep", _get_tpu_strategy_creator(steps_per_run=1), required_tpu=True)
tpu_strategy_one_core = combinations.NamedDistribution(
    "TPUOneCore",
    _get_tpu_strategy_creator(steps_per_run=2, use_single_core=True),
    required_tpu=True)
tpu_strategy_one_step_one_core = combinations.NamedDistribution(
    "TPUOneStepOneCore",
    _get_tpu_strategy_creator(steps_per_run=1, use_single_core=True),
    required_tpu=True)
cloud_tpu_strategy = combinations.NamedDistribution(
    "CloudTPU",
    _get_tpu_strategy_creator(steps_per_run=2),
    required_tpu=True,
    use_cloud_tpu=True)
mirrored_strategy_with_one_cpu = combinations.NamedDistribution(
    "Mirrored1CPU",
    lambda: _mirrored_strategy_with_collective_key_base(["/cpu:0"]))
mirrored_strategy_with_one_gpu = combinations.NamedDistribution(
    "Mirrored1GPU",
    lambda: _mirrored_strategy_with_collective_key_base(["/gpu:0"]),
    required_gpus=1)
mirrored_strategy_with_gpu_and_cpu = combinations.NamedDistribution(
    "MirroredCPUAndGPU",
    lambda: _mirrored_strategy_with_collective_key_base(["/gpu:0", "/cpu:0"]),
    required_gpus=1)
mirrored_strategy_with_two_cpus = combinations.NamedDistribution(
    "Mirrored2CPUs",
    lambda: _mirrored_strategy_with_collective_key_base(["/cpu:0", "/cpu:1"]),
    required_gpus=0)
mirrored_strategy_with_two_gpus = combinations.NamedDistribution(
    "Mirrored2GPUs",
    lambda: _mirrored_strategy_with_collective_key_base(["/gpu:0", "/gpu:1"]),
    required_gpus=2)
mirrored_strategy_with_two_gpus_no_merge_call = combinations.NamedDistribution(
    "Mirrored2GPUsNoMergeCall",
    lambda: _mirrored_strategy_with_no_merge_call(["/gpu:0", "/gpu:1"]),
    required_physical_gpus=2)
# Tests should call set_virtual_cpus_to_at_least(3) in their setUp method.
# Deprecated; use mirrored_strategy_with_two_cpus instead.
mirrored_strategy_with_cpu_1_and_2 = combinations.NamedDistribution(
    "Mirrored2CPU",
    lambda: _mirrored_strategy_with_collective_key_base(["/cpu:1", "/cpu:2"]))
mirrored_strategy_with_cpu_1_and_2.__doc__ = (
    """Mirrored strategy with 2 virtual CPUs.

    Logical devices should be set up before use.
    """)
central_storage_strategy_with_two_gpus = combinations.NamedDistribution(
    "CentralStorage2GPUs",
    lambda: CentralStorageStrategy(["/gpu:0", "/gpu:1"]),
    required_gpus=2)
central_storage_strategy_with_gpu_and_cpu = combinations.NamedDistribution(
    "CentralStorageCPUAndGPU",
    lambda: CentralStorageStrategy(["/gpu:0", "/cpu:0"]),
    required_gpus=1)
# chief + 1 worker, with CPU.
multi_worker_mirrored_2x1_cpu = combinations.NamedDistribution(
    "MultiWorkerMirrored2x1CPU",
    _get_multi_worker_mirrored_creator(required_gpus=0),
    has_chief=True,
    num_workers=1,
    pool_runner_fn=_two_worker_pool,
    no_xla=True,
)
# chief + 1 worker, with 1 GPU each.
multi_worker_mirrored_2x1_gpu = combinations.NamedDistribution(
    "MultiWorkerMirrored2x1GPU",
    _get_multi_worker_mirrored_creator(required_gpus=1),
    has_chief=True,
    num_workers=1,
    required_gpus=1,
    pool_runner_fn=_two_worker_pool,
    share_gpu=False,
)

# Same as above, but using a pool runner that does not share GPUs between
# the workers.
multi_worker_mirrored_2x1_gpu_noshare = combinations.NamedDistribution(
    "MultiWorkerMirrored2x1GPUNoShare",
    _get_multi_worker_mirrored_creator(required_gpus=1),
    has_chief=True,
    num_workers=1,
    required_gpus=1,
    pool_runner_fn=_two_worker_pool_noshare,
    share_gpu=False,
)
# chief + 1 worker, with 2 GPUs each.
multi_worker_mirrored_2x2_gpu = combinations.NamedDistribution(
    "MultiWorkerMirrored2x2GPU",
    _get_multi_worker_mirrored_creator(required_gpus=2),
    has_chief=True,
    num_workers=1,
    required_gpus=2,
    pool_runner_fn=_two_worker_pool,
    no_xla=True,
)
multi_worker_mirrored_2x2_gpu_no_merge_call = combinations.NamedDistribution(
    "MultiWorkerMirrored2x2GPUNoMergeCall",
    _get_multi_worker_mirrored_creator(required_gpus=2, use_merge_call=False),
    has_chief=True,
    num_workers=1,
    required_physical_gpus=2,
    pool_runner_fn=_two_worker_pool,
    no_xla=True,
)
# chief + 3 workers, with CPU.
multi_worker_mirrored_4x1_cpu = combinations.NamedDistribution(
    "MultiWorkerMirrored4x1CPU",
    _get_multi_worker_mirrored_creator(required_gpus=0),
    has_chief=True,
    num_workers=3,
    pool_runner_fn=_four_worker_pool,
    no_xla=True,
)


def parameter_server_strategy_fn(name,
                                 num_workers,
                                 num_ps,
                                 required_gpus=0,
                                 variable_partitioner=DEFAULT_PARTITIONER):
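  """Returns a NamedDistribution that creates a ParameterServerStrategyV2.

  Args:
    name: name of the combination.
    num_workers: number of workers in the cluster.
    num_ps: number of parameter servers in the cluster.
    required_gpus: number of GPUs required by each worker.
    variable_partitioner: the variable partitioner passed to the strategy.

  Returns:
    A `combinations.NamedDistribution`.
  """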
  return combinations.NamedDistribution(
      name,
      _get_ps_strategy_creator(
          num_workers=num_workers,
          num_ps=num_ps,
          required_gpus=required_gpus,
          variable_partitioner=variable_partitioner),
      required_gpus=required_gpus,
      num_workers=num_workers,
      has_chief=True,
      num_ps=num_ps)


parameter_server_strategy_3worker_2ps_cpu = parameter_server_strategy_fn(
    "ParameterServer3Worker2PSCPU", num_workers=3, num_ps=2)
parameter_server_strategy_1worker_2ps_cpu = parameter_server_strategy_fn(
    "ParameterServer1Worker2PSCPU", num_workers=1, num_ps=2)
parameter_server_strategy_3worker_2ps_1gpu = parameter_server_strategy_fn(
    "ParameterServer3Worker2PS1GPU", num_workers=3, num_ps=2, required_gpus=1)
parameter_server_strategy_1worker_2ps_1gpu = parameter_server_strategy_fn(
    "ParameterServer1Worker2PS1GPU", num_workers=1, num_ps=2, required_gpus=1)

graph_and_eager_modes = ["graph", "eager"]


# TODO(crccw): remove after tf-nightly picks up the new API.
def set_virtual_cpus_to_at_least(num_virtual_cpus):
  test_util.set_logical_devices_to_at_least("CPU", num_virtual_cpus)


strategies_minus_tpu = [
    default_strategy,
    one_device_strategy,
    one_device_strategy_gpu,
    mirrored_strategy_with_gpu_and_cpu,
    mirrored_strategy_with_two_gpus,
    central_storage_strategy_with_gpu_and_cpu,
]

strategies_minus_default_and_tpu = [
    one_device_strategy,
    one_device_strategy_gpu,
    mirrored_strategy_with_gpu_and_cpu,
    mirrored_strategy_with_two_gpus,
]

tpu_strategies = [
    tpu_strategy,  # steps_per_run=2
    tpu_strategy_one_step,
    tpu_strategy_packed_var,
    cloud_tpu_strategy,
]

all_strategies_minus_default = strategies_minus_default_and_tpu + tpu_strategies

all_strategies = strategies_minus_tpu + tpu_strategies

two_replica_strategies = [
    mirrored_strategy_with_gpu_and_cpu,
    mirrored_strategy_with_two_gpus,
    multi_worker_mirrored_2x1_cpu,
    multi_worker_mirrored_2x1_gpu,
    tpu_strategy,  # steps_per_run=2
    tpu_strategy_one_step,
    central_storage_strategy_with_gpu_and_cpu,
]

four_replica_strategies = [
    multi_worker_mirrored_2x2_gpu,
    multi_worker_mirrored_4x1_cpu,
]

# TODO(b/159831907): replace with two_replica_strategies after the tests using
# it work with MWMS.
multidevice_strategies = [
    mirrored_strategy_with_gpu_and_cpu,
    mirrored_strategy_with_two_gpus,
    tpu_strategy,  # steps_per_run=2
    tpu_strategy_one_step
]

multiworker_strategies = [
    multi_worker_mirrored_2x1_cpu, multi_worker_mirrored_2x1_gpu,
    multi_worker_mirrored_2x2_gpu
]


def strategy_minus_tpu_combinations():
  return combinations.combine(
      distribution=strategies_minus_tpu, mode=["graph", "eager"])


def tpu_strategy_combinations():
  return combinations.combine(distribution=tpu_strategies, mode=["graph"])


def all_strategy_combinations():
  return strategy_minus_tpu_combinations() + tpu_strategy_combinations()


def all_strategy_minus_default_and_tpu_combinations():
  return combinations.combine(
      distribution=[
          one_device_strategy, one_device_strategy_gpu,
          mirrored_strategy_with_gpu_and_cpu, mirrored_strategy_with_two_gpus
      ],
      mode=["graph", "eager"])


def all_strategy_combinations_minus_default():
  return (all_strategy_minus_default_and_tpu_combinations() +
          tpu_strategy_combinations())

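
# Export the combinations above under the
# `tf.__internal__.distribute.combinations.*` namespace so that tests outside
# of TensorFlow can reuse them.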
tf_export(
    _TF_INTERNAL_API_PREFIX + "central_storage_strategy_with_gpu_and_cpu",
    v1=[]).export_constant(__name__,
                           "central_storage_strategy_with_gpu_and_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "central_storage_strategy_with_two_gpus",
    v1=[]).export_constant(__name__, "central_storage_strategy_with_two_gpus")
tf_export(
    _TF_INTERNAL_API_PREFIX + "cloud_tpu_strategy",
    v1=[]).export_constant(__name__, "cloud_tpu_strategy")
tf_export(
    _TF_INTERNAL_API_PREFIX + "default_strategy",
    v1=[]).export_constant(__name__, "default_strategy")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_cpu_1_and_2",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_cpu_1_and_2")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_gpu_and_cpu",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_gpu_and_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_one_cpu",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_one_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_one_gpu",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_one_gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_two_gpus",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_two_gpus")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_two_gpus_no_merge_call",
    v1=[]).export_constant(__name__,
                           "mirrored_strategy_with_two_gpus_no_merge_call")
tf_export(
    _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x1_cpu",
    v1=[]).export_constant(__name__, "multi_worker_mirrored_2x1_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x1_gpu",
    v1=[]).export_constant(__name__, "multi_worker_mirrored_2x1_gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x1_gpu_noshare",
    v1=[]).export_constant(__name__, "multi_worker_mirrored_2x1_gpu_noshare")
tf_export(
    _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x2_gpu",
    v1=[]).export_constant(__name__, "multi_worker_mirrored_2x2_gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x2_gpu_no_merge_call",
    v1=[]).export_constant(__name__,
                           "multi_worker_mirrored_2x2_gpu_no_merge_call")
tf_export(
    _TF_INTERNAL_API_PREFIX + "one_device_strategy",
    v1=[]).export_constant(__name__, "one_device_strategy")
tf_export(
    _TF_INTERNAL_API_PREFIX + "one_device_strategy_gpu",
    v1=[]).export_constant(__name__, "one_device_strategy_gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "tpu_strategy",
    v1=[]).export_constant(__name__, "tpu_strategy")
tf_export(
    _TF_INTERNAL_API_PREFIX + "parameter_server_strategy_3worker_2ps_cpu",
    v1=[]).export_constant(__name__,
                           "parameter_server_strategy_3worker_2ps_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "parameter_server_strategy_1worker_2ps_cpu",
    v1=[]).export_constant(__name__,
                           "parameter_server_strategy_1worker_2ps_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "parameter_server_strategy_3worker_2ps_1gpu",
    v1=[]).export_constant(__name__,
                           "parameter_server_strategy_3worker_2ps_1gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "parameter_server_strategy_1worker_2ps_1gpu",
    v1=[]).export_constant(__name__,
                           "parameter_server_strategy_1worker_2ps_1gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "tpu_strategy_one_core",
    v1=[]).export_constant(__name__, "tpu_strategy_one_core")
tf_export(
    _TF_INTERNAL_API_PREFIX + "tpu_strategy_packed_var",
    v1=[]).export_constant(__name__, "tpu_strategy_packed_var")