# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class CollectiveAllReduceStrategy implementing DistributionStrategy."""

import copy
import threading
import time
import weakref

from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import input_util
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import ClusterResolver
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.trackable import base
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


# pylint: disable=line-too-long
@tf_export("distribute.MultiWorkerMirroredStrategy", v1=[])
class CollectiveAllReduceStrategy(distribute_lib.Strategy):
  """A distribution strategy for synchronous training on multiple workers.

  This strategy implements synchronous distributed training across multiple
  workers, each with potentially multiple GPUs. Similar to
  `tf.distribute.MirroredStrategy`, it replicates all variables and computations
  to each local device. The difference is that it uses a distributed collective
  implementation (e.g. all-reduce), so that multiple workers can work together.

  You need to launch your program on each worker and configure
  `cluster_resolver` correctly.
  For example, if you are using
  `tf.distribute.cluster_resolver.TFConfigClusterResolver`, each worker needs to
  have its corresponding `task_type` and `task_id` set in the `TF_CONFIG`
  environment variable. An example TF_CONFIG on worker-0 of a two-worker cluster
  is:

  ```
  TF_CONFIG = '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }'
  ```

  Your program runs on each worker as-is. Note that collectives require each
  worker to participate. All `tf.distribute` and non-`tf.distribute` APIs may
  use collectives internally, e.g. checkpointing and saving, since reading a
  `tf.Variable` with `tf.VariableSynchronization.ON_READ` all-reduces the value.
  Therefore it's recommended to run exactly the same program on each worker.
  Dispatching based on `task_type` or `task_id` of the worker is error-prone.

  `cluster_resolver.num_accelerators()` determines the number of GPUs the
  strategy uses. If it's zero, the strategy uses the CPU. All workers need to
  use the same number of devices, otherwise the behavior is undefined.

  This strategy is not intended for TPU. Use `tf.distribute.TPUStrategy`
  instead.

  After setting up TF_CONFIG, using this strategy is similar to using
  `tf.distribute.MirroredStrategy` and `tf.distribute.TPUStrategy`.

  ```
  strategy = tf.distribute.MultiWorkerMirroredStrategy()

  with strategy.scope():
    model = tf.keras.Sequential([
      tf.keras.layers.Dense(2, input_shape=(5,)),
    ])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

  def dataset_fn(ctx):
    x = np.random.random((2, 5)).astype(np.float32)
    y = np.random.randint(2, size=(2, 1))
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    return dataset.repeat().batch(1, drop_remainder=True)
  dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)

  model.compile()
  model.fit(dist_dataset)
  ```

  You can also write your own training loop:

  ```
  @tf.function
  def train_step(iterator):

    def step_fn(inputs):
      features, labels = inputs
      with tf.GradientTape() as tape:
        logits = model(features, training=True)
        loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits)

      grads = tape.gradient(loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

    strategy.run(step_fn, args=(next(iterator),))

  for _ in range(NUM_STEP):
    train_step(iterator)
  ```

  See
  [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)
  for a detailed tutorial.

  __Saving__

  You need to save and checkpoint on all workers instead of just one. This is
  because variables with `synchronization=ON_READ` trigger aggregation during
  saving. It's recommended to save to a different path on each worker to avoid
  race conditions. Each worker saves the same thing. See the
  [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#model_saving_and_loading)
  tutorial for examples.

  __Known Issues__

  * `tf.distribute.cluster_resolver.TFConfigClusterResolver` does not return the
  correct number of accelerators. The strategy uses all available GPUs if
  `cluster_resolver` is `tf.distribute.cluster_resolver.TFConfigClusterResolver`
  or `None`.
  * In eager mode, the strategy needs to be created before calling any other
  TensorFlow API.

  """
  # pylint: enable=line-too-long

  # TODO(anjalisridhar): Update our guides with examples showing how we can use
  # the cluster_resolver argument.

  # The starting number for collective keys. This should only be set in tests.
  _collective_key_base = 0

  def __init__(self,
               cluster_resolver=None,
               communication_options=None):
    """Creates the strategy.

    Args:
      cluster_resolver: optional
        `tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
        `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
      communication_options: optional
        `tf.distribute.experimental.CommunicationOptions`. This configures the
        default options for cross device communications. It can be overridden
        by options provided to the communication APIs like
        `tf.distribute.ReplicaContext.all_reduce`. See
        `tf.distribute.experimental.CommunicationOptions` for details.
    """
    if communication_options is None:
      communication_options = collective_util.Options()
    super(CollectiveAllReduceStrategy, self).__init__(
        CollectiveAllReduceExtended(
            self,
            cluster_resolver=cluster_resolver,
            communication_options=communication_options))

    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "MultiWorkerMirroredStrategy")
    # pylint: disable=protected-access
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_workers").set(self.extended._num_workers)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_replicas_per_worker").set(self.extended._num_devices_per_worker)

  @classmethod
  def _from_local_devices(cls, devices, communication_options=None):
    """A convenience method to create an object with a list of devices."""
    obj = cls(communication_options=communication_options)
    obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices)  # pylint: disable=protected-access
    return obj

  @property
  def cluster_resolver(self):
    """Returns the cluster resolver associated with this strategy.

    As a multi-worker strategy, `tf.distribute.MultiWorkerMirroredStrategy`
    provides the associated `tf.distribute.cluster_resolver.ClusterResolver`. If
    the user provides one in `__init__`, that instance is returned; if the user
    does not, a default `TFConfigClusterResolver` is provided.
    """
    return self.extended._cluster_resolver  # pylint: disable=protected-access


class _CollectiveAllReduceStrategyExperimentalMeta(type):

  @classmethod
  def __instancecheck__(cls, instance):
    # This is to make isinstance(tf.distribute.MultiWorkerMirroredStrategy(),
    # tf.distribute.experimental.MultiWorkerMirroredStrategy) return True. Some
    # libraries perform such checks.
    return isinstance(instance, CollectiveAllReduceStrategy)


@tf_export("distribute.experimental.MultiWorkerMirroredStrategy", v1=[])
class _CollectiveAllReduceStrategyExperimental(
    CollectiveAllReduceStrategy,
    metaclass=_CollectiveAllReduceStrategyExperimentalMeta):

  __doc__ = CollectiveAllReduceStrategy.__doc__

  @deprecation.deprecated(
      None, "use distribute.MultiWorkerMirroredStrategy instead")
  def __init__(self,
               communication=collective_util.CommunicationImplementation.AUTO,
               cluster_resolver=None):
    """Creates the strategy.

    Args:
      communication: optional
        `tf.distribute.experimental.CommunicationImplementation`. This is a hint
        on the preferred collective communication implementation. Possible
        values include `AUTO`, `RING`, and `NCCL`.
      cluster_resolver: optional
        `tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
        `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
    """
    communication_options = collective_util.Options(
        implementation=communication)
    super(_CollectiveAllReduceStrategyExperimental,
          self).__init__(cluster_resolver, communication_options)

  @classmethod
  def _from_local_devices(
      cls,
      devices,
      communication=collective_util.CommunicationImplementation.AUTO):
    """A convenience method to create an object with a list of devices."""
    obj = cls(communication)
    obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices)  # pylint: disable=protected-access
    return obj


_CollectiveAllReduceStrategyExperimental.__name__ = CollectiveAllReduceStrategy.__name__


@tf_export(v1=["distribute.experimental.MultiWorkerMirroredStrategy"])  # pylint: disable=missing-docstring
class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1):

  __doc__ = CollectiveAllReduceStrategy.__doc__

  # The starting number for collective keys. This should only be set in tests.
  _collective_key_base = 0

  def __init__(self,
               communication=collective_util.CommunicationImplementation.AUTO,
               cluster_resolver=None):
    """Initializes the object."""
    communication_options = collective_util.Options(
        implementation=communication)
    super(CollectiveAllReduceStrategyV1, self).__init__(
        CollectiveAllReduceExtended(
            self,
            cluster_resolver=cluster_resolver,
            communication_options=communication_options))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "MultiWorkerMirroredStrategy")
    # pylint: disable=protected-access
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_workers").set(self.extended._num_workers)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_gpu_per_worker").set(
            self.extended._num_devices_per_worker
            if self.extended._local_device_type == "GPU"
            else 0)


def _is_gpu_device(device):
  return tf_device.DeviceSpec.from_string(device).device_type == "GPU"


class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
  """Implementation of CollectiveAllReduceStrategy."""

  # Whether to periodically check the health of the cluster. If any worker is
  # not reachable, collectives are aborted and the user program should get a
  # tf.errors.UnavailableError. A restart is required to recover.
  _enable_check_health = True
  # Check health interval in seconds.
  _check_health_interval = 30
  # Timeout in seconds for the first check health. The first check health needs
  # to wait for the cluster, which may take a longer time.
  _check_health_initial_timeout = 0
  # Times to retry before considering the peer is down.
  _check_health_retry_limit = 3
  # Timeout in seconds for each check health.
  _check_health_timeout = 10

  def __init__(self, container_strategy, cluster_resolver,
               communication_options, devices=None):
    if not isinstance(communication_options, collective_util.Options):
      raise ValueError("communication_options must be an instance of "
                       "tf.distribute.experimental.CommunicationOptions")
    if cluster_resolver and devices:
      raise ValueError(
          "cluster_resolver and devices cannot be set at the same time")

    self._cluster_resolver = cluster_resolver or TFConfigClusterResolver()
    if not isinstance(self._cluster_resolver, ClusterResolver):
      raise ValueError("cluster_resolver must be an instance of "
                       "tf.distribute.cluster_resolver.ClusterResolver")
    distribute_lib.StrategyExtendedV1.__init__(self, container_strategy)
    self._communication_options = communication_options
    self._collective_key_base = container_strategy._collective_key_base  # pylint: disable=protected-access
    self._initialize_strategy(self._cluster_resolver, devices=devices)
    self._cfer_fn_cache = weakref.WeakKeyDictionary()
    self.experimental_enable_get_next_as_optional = True
    assert isinstance(self._cross_device_ops,
                      cross_device_ops_lib.CollectiveAllReduce)

  def _use_merge_call(self):
    # We currently only disable merge_call when XLA is used to compile the `fn`
    # passed to `strategy.run` and all devices are GPU.
    return not control_flow_util.GraphOrParentsInXlaContext(
        ops.get_default_graph()) or not all(
            [_is_gpu_device(d) for d in self._devices])

  def _initialize_strategy(self, cluster_resolver, devices):
    # If devices are provided or cluster_spec is not specified, initialize a
    # single worker. Otherwise initialize multi-worker.
    if devices or not cluster_resolver.cluster_spec().as_dict():
      self._initialize_local(cluster_resolver, devices=devices)
    else:
      self._initialize_multi_worker(cluster_resolver)

  def _initialize_local_devices(self, cluster_resolver, worker_device):
    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
    # some cases.
    if isinstance(cluster_resolver, TFConfigClusterResolver):
      num_gpus = context.num_gpus()
      num_tpus = 0
    else:
      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)
      num_tpus = cluster_resolver.num_accelerators().get("TPU", 0)

    if num_gpus:
      local_device_type = "GPU"
      num_local_devices = num_gpus
    elif num_tpus:
      local_device_type = "TPU"
      num_local_devices = num_tpus
    else:
      local_device_type = "CPU"
      num_local_devices = 1
    local_devices = tuple(
        f"{worker_device}/device:{local_device_type}:{i}"
        for i in range(num_local_devices))
    return local_devices, local_device_type

  def _initialize_local(self, cluster_resolver, devices=None):
    """Initializes the object for local training."""
    self._is_chief = True
    self._num_workers = 1

    if ops.executing_eagerly_outside_functions():
      try:
        context.context().configure_collective_ops(
            scoped_allocator_enabled_ops=("CollectiveReduce",))
      except RuntimeError:
        logging.warning("Collective ops are not configured at program startup. "
                        "Some performance features may not be enabled.")
      self._collective_ops_configured = True

    if devices:
      local_devices = devices
      if "GPU" in devices[0]:
        local_device_type = "GPU"
      elif "TPU" in devices[0]:
        local_device_type = "TPU"
      else:
        local_device_type = "CPU"
    else:
      local_devices, local_device_type = self._initialize_local_devices(
          cluster_resolver, worker_device="")

    self._worker_device = device_util.canonicalize("/device:CPU:0")
    self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)

    self._collective_keys = cross_device_utils.CollectiveKeys(
        group_key_start=1 + self._collective_key_base)
    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=local_devices,
        group_size=len(local_devices),
        options=self._communication_options,
        collective_keys=self._collective_keys)
    # CrossDeviceOps for per host tensors.
    self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=[self._worker_device],
        group_size=self._num_workers,
        options=self._communication_options,
        collective_keys=self._collective_keys)
    super(CollectiveAllReduceExtended, self)._initialize_single_worker(
        local_devices)

    self._cluster_spec = None
    self._task_type = None
    self._task_id = None
    self._id_in_cluster = 0

    # This is a mark to tell whether we are running with standalone client or
    # independent worker. Right now with standalone client, the strategy object
    # is created as a local strategy and then turned into a multi-worker
    # strategy via the configure call.
    self._local_or_standalone_client_mode = True

    # Save the num_devices_per_worker and rpc_layer for configure method.
    self._num_devices_per_worker = len(local_devices)
    self._local_device_type = local_device_type
    self._rpc_layer = cluster_resolver.rpc_layer
    self._warn_nccl_no_gpu()

    logging.info(
        "Single-worker MultiWorkerMirroredStrategy with local_devices "
        "= %r, communication = %s", local_devices,
        self._communication_options.implementation)

  def _initialize_multi_worker(self, cluster_resolver):
    """Initializes the object for multi-worker training."""
    cluster_spec = multi_worker_util.normalize_cluster_spec(
        cluster_resolver.cluster_spec())
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    if task_type is None or task_id is None:
      raise ValueError("When `cluster_spec` is given, you must also specify "
                       "`task_type` and `task_id`.")
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id
    self._id_in_cluster = multi_worker_util.id_in_cluster(
        self._cluster_spec, self._task_type, self._task_id)

    self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type)
    if not self._num_workers:
      raise ValueError("No `worker`, `chief` or `evaluator` tasks can be found "
                       "in `cluster_spec`.")

    self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                task_id)

    self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
    self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)

    if (ops.executing_eagerly_outside_functions() and
        not getattr(self, "_local_or_standalone_client_mode", False)):
      context.context().configure_collective_ops(
          collective_leader=multi_worker_util.collective_leader(
              cluster_spec, task_type, task_id),
          scoped_allocator_enabled_ops=("CollectiveReduce",),
          device_filters=("/job:%s/task:%d" % (task_type, task_id),))
      self._collective_ops_configured = True
      if context.context().coordination_service is None:
        coordinated_jobs = ["chief", "worker"]
        if task_type in coordinated_jobs:
          context.context().configure_coordination_service(
              service_type="standalone",
              service_leader=multi_worker_util.coordination_leader(
                  cluster_spec),
              coordinated_jobs=coordinated_jobs)

    # Starting a std server in eager mode and in independent worker mode.
    if (context.executing_eagerly() and
        not getattr(self, "_std_server_started", False) and
        not getattr(self, "_local_or_standalone_client_mode", False)):
      # Checking _local_or_standalone_client_mode as well because we should not
      # create the std server in standalone client mode.
      config_proto = copy.deepcopy(context.context().config)
      config_proto = self._update_config_proto(config_proto)

      # If coordination service is enabled, use its internal heartbeat to detect
      # peer failures instead of the Python-level health check.
      if config_proto.experimental.coordination_config.service_type:
        self._enable_check_health = False

      if hasattr(cluster_resolver, "port"):
        port = cluster_resolver.port
      else:
        port = 0
      server_def = tensorflow_server_pb2.ServerDef(
          cluster=cluster_spec.as_cluster_def(),
          default_session_config=config_proto,
          job_name=task_type,
          task_index=task_id,
          protocol=cluster_resolver.rpc_layer or "grpc",
          port=port)
      context.context().enable_collective_ops(server_def)
      self._std_server_started = True
      # The `ensure_initialized` is needed before calling
      # `context.context().devices()`.
      context.context().ensure_initialized()
      logging.info(
          "Enabled multi-worker collective ops with available devices: %r",
          context.context().devices())

    # TODO(yuefengz): The `num_gpus` is only for this particular task. It
    # assumes all workers have the same number of GPUs. We should remove this
    # assumption by querying all tasks for their numbers of GPUs.
    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
    # some cases.
    local_devices, local_device_type = self._initialize_local_devices(
        cluster_resolver, self._worker_device)
    if local_device_type == "TPU":
      tpu_strategy_util.initialize_tpu_system()

    self._collective_keys = cross_device_utils.CollectiveKeys(
        group_key_start=1 + self._collective_key_base)
    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=local_devices,
        group_size=len(local_devices) * self._num_workers,
        options=self._communication_options,
        collective_keys=self._collective_keys)
    # CrossDeviceOps for per host tensors.
    self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=[self._worker_device],
        group_size=self._num_workers,
        options=self._communication_options,
        collective_keys=self._collective_keys)
    super(CollectiveAllReduceExtended, self)._initialize_single_worker(
        local_devices)

    # Add a default device so that ops without specified devices will not end up
    # on other workers.
    self._default_device = "/job:%s/task:%d" % (task_type, task_id)

    # Save the num_devices_per_worker and rpc_layer for configure method.
    self._num_devices_per_worker = len(local_devices)
    self._local_device_type = local_device_type
    self._rpc_layer = cluster_resolver.rpc_layer
    self._warn_nccl_no_gpu()

    if self._enable_check_health and context.executing_eagerly():
      self._start_check_health_thread()
    else:
      logging.info("Check health not enabled.")

    logging.info(
        "MultiWorkerMirroredStrategy with cluster_spec = %r, task_type = %r, "
        "task_id = %r, num_workers = %r, local_devices = %r, "
        "communication = %s", cluster_spec.as_dict(), task_type, task_id,
        self._num_workers, local_devices,
        self._communication_options.implementation)

  def __del__(self):
    self._stop_check_health_thread()

  def _input_workers_with_options(self, options=None):
    host_device = device_util.get_host_for_device(self._worker_device)
    if not options or options.experimental_fetch_to_device:
      return input_lib.InputWorkers([(host_device, self.worker_devices)])
    else:
      return input_lib.InputWorkers([(
          host_device,
          [device_util.get_host_for_device(worker) for worker in
           self.worker_devices])])

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _get_variable_creator_initial_value(self,
                                          replica_id,
                                          device,
                                          primary_var,
                                          **kwargs):
    if replica_id == 0:  # First replica on each worker.
      assert device is not None
      assert primary_var is None

      def initial_value_fn():  # pylint: disable=g-missing-docstring
        # Only the first device participates in the broadcast of initial values.
        group_key = self._collective_keys.get_group_key([device])
        group_size = self._num_workers
        collective_instance_key = (
            self._collective_keys.get_instance_key(group_key, device))

        with ops.device(device):
          initial_value = kwargs["initial_value"]
          if callable(initial_value):
            initial_value = initial_value()
          if isinstance(initial_value, base.CheckpointInitialValue):
            initial_value = initial_value.wrapped_value
          assert not callable(initial_value)
          initial_value = ops.convert_to_tensor(
              initial_value, dtype=kwargs.get("dtype", None))

          if self._num_workers > 1:
            if self._is_chief:
              bcast_send = collective_ops.broadcast_send(
                  initial_value, initial_value.shape, initial_value.dtype,
                  group_size, group_key, collective_instance_key)
              with ops.control_dependencies([bcast_send]):
                return array_ops.identity(initial_value)
            else:
              return collective_ops.broadcast_recv(initial_value.shape,
                                                   initial_value.dtype,
                                                   group_size, group_key,
                                                   collective_instance_key)
          return initial_value

      return initial_value_fn
    else:
      return super(CollectiveAllReduceExtended,
                   self)._get_variable_creator_initial_value(
                       replica_id=replica_id,
                       device=device,
                       primary_var=primary_var,
                       **kwargs)

  def _make_input_context(self):
    input_context = distribute_lib.InputContext(
        num_input_pipelines=self._num_workers,
        input_pipeline_id=self._id_in_cluster,
        num_replicas_in_sync=self._num_replicas_in_sync)
    return input_context

  def _experimental_distribute_dataset(self, dataset, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy"
      )
    input_context = self._make_input_context()
    return input_util.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync,
        input_context=input_context,
        options=options)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    input_context = self._make_input_context()
    return input_util.get_distributed_datasets_from_function(
        dataset_fn=dataset_fn,
        input_workers=self._input_workers_with_options(options),
        input_contexts=[input_context],
        strategy=self._container_strategy(),
        options=options)

  def _experimental_distribute_values_from_function(self, value_fn):
    per_replica_values = []
    num_local_replicas = len(self.worker_devices)
    for local_replica_id in range(num_local_replicas):
      replica_id = (self._id_in_cluster * num_local_replicas +
                    local_replica_id)
      value_context = distribute_lib.ValueContext(
          replica_id, self._num_replicas_in_sync)
      per_replica_values.append(value_fn(value_context))
    return distribute_utils.regroup(per_replica_values, always_wrap=True)

  def _make_dataset_iterator(self, dataset):
    """Distributes the dataset to each local GPU."""
    input_context = self._make_input_context()
    return input_lib_v1.DatasetIterator(
        dataset,
        self._input_workers,
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync,
        input_context=input_context)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    """Distributes the input function to each local GPU."""
    input_context = self._make_input_context()
    return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers,
                                              [input_context],
                                              self._container_strategy())

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    """Configures the object.

    Args:
      session_config: a `tf.compat.v1.ConfigProto`
      cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
        cluster configurations.
      task_type: the current task type, such as "worker".
      task_id: the current task id.

    Raises:
      ValueError: if `task_type` is not in the `cluster_spec`.
    """
    if cluster_spec:
      cluster_resolver = SimpleClusterResolver(
          cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
          task_type=task_type,
          task_id=task_id,
          num_accelerators={
              self._local_device_type: self._num_devices_per_worker},
          rpc_layer=self._rpc_layer)
      self._initialize_multi_worker(cluster_resolver)
      assert isinstance(self._cross_device_ops,
                        cross_device_ops_lib.CollectiveAllReduce)

    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    # Enable the scoped allocator optimization for CollectiveOps. This
    # optimization converts many small all-reduces into fewer larger
    # all-reduces.
    rewrite_options = updated_config.graph_options.rewrite_options
    rewrite_options.scoped_allocator_optimization = (
        rewriter_config_pb2.RewriterConfig.ON)
    # We turn on ScopedAllocator only for CollectiveReduce op, i.e. enable_op =
    # ["CollectiveReduce"]. Since we can't assign to a repeated proto field, we
    # clear and then append.
    del rewrite_options.scoped_allocator_opts.enable_op[:]
    rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce")

    if (not ops.executing_eagerly_outside_functions() and
        self._communication_options.implementation ==
        collective_util.CommunicationImplementation.NCCL):
      updated_config.experimental.collective_nccl = True

    if not self._cluster_spec:
      return updated_config

    assert self._task_type
    assert self._task_id is not None

    # Collective group leader is needed for collective ops to coordinate
    # workers.
    updated_config.experimental.collective_group_leader = (
        multi_worker_util.collective_leader(self._cluster_spec, self._task_type,
                                            self._task_id))

    # The device filters prevent communication between workers.
    del updated_config.device_filters[:]
    updated_config.device_filters.append(
        "/job:%s/task:%d" % (self._task_type, self._task_id))

    return updated_config

  def _get_cross_device_ops(self, value):
    # CollectiveAllReduce works on a predefined set of devices. In most cases
    # they should be the compute devices, but certain use cases may reduce host
    # tensors as well (e.g. early stopping). We infer the cross_device_ops to
    # use based on the number of devices, since inputs don't always have device
    # annotations. The one for compute devices is preferred since we can
    # potentially leverage NCCL.
    if isinstance(value, values.DistributedValues):
      num_devices = len(value._values)  # pylint: disable=protected-access
    else:
      num_devices = 1
    if num_devices == len(self.worker_devices):
      return self._cross_device_ops
    else:
      return self._host_cross_device_ops

  def _gather_to_implementation(self, value, destinations, axis, options):
    return self._get_cross_device_ops(value)._gather(  # pylint: disable=protected-access
        value,
        destinations=destinations,
        axis=axis,
        options=options)

  def _reduce_to(self, reduce_op, value, destinations, options):
    if (isinstance(value, values.Mirrored) and
        reduce_op == reduce_util.ReduceOp.MEAN):
      return value
    assert not isinstance(value, values.Mirrored)

    if (isinstance(value, values.DistributedValues) and
        len(self.worker_devices) == 1):
      value = value.values[0]

    # When there are multiple workers, we need to reduce across workers using
    # collective ops.
    if (not isinstance(value, values.DistributedValues) and
        self._num_workers == 1):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas, in which case `value` would be a single value, or `value`
      # could be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, value, destinations, len(self.worker_devices))
    return self._get_cross_device_ops(value).reduce(
        reduce_op,
        value,
        destinations=destinations,
        options=self._communication_options.merge(options))

  def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
    """Implements `StrategyExtendedV2._replica_ctx_all_reduce`."""
    # This implementation avoids using `merge_call` and just launches collective
    # ops in one replica.
    if options is None:
      options = collective_util.Options()

    if context.executing_eagerly():
      # In eager mode, fall back to the default implementation that uses
      # `merge_call`. Replica functions run sequentially in eager mode, and due
      # to the blocking nature of collective ops, execution will hang if
      # collective ops are launched sequentially.
      return super()._replica_ctx_all_reduce(reduce_op, value, options)

    replica_context = ds_context.get_replica_context()
    assert replica_context, (
        "`StrategyExtended._replica_ctx_all_reduce` must be called in a "
        "replica context")
    return self._cross_device_ops._all_reduce(  # pylint: disable=protected-access
        reduce_op,
        value,
        replica_context._replica_id,  # pylint: disable=protected-access
        options)

  def _check_health(self):
    while True:
      if self._check_health_thread_should_stop.is_set():
        return
      for job in self._cluster_spec.jobs:
        for task_id in range(self._cluster_spec.num_tasks(job)):
          peer = "/job:{}/replica:0/task:{}".format(job, task_id)
          attempts = 0
          while True:
            attempts += 1
            try:
              context.context().check_collective_ops_peer_health(
                  peer, timeout_in_ms=self._check_health_timeout * 1000)
              # If check_collective_ops_peer_health doesn't raise an Exception,
              # the peer is healthy.
              break
            except (errors.UnavailableError, errors.FailedPreconditionError,
                    errors.DeadlineExceededError) as e:
              # TODO(b/151232436): Always raise UnavailableError when a peer
              # fails. Now there could be many kinds of errors:
              # - Unavailable: when the peer is not reachable, e.g. it's down.
              # - FailedPrecondition: when the peer has restarted.
              if attempts < self._check_health_retry_limit:
                logging.warning("%s seems down, retrying %d/%d", peer, attempts,
                                self._check_health_retry_limit)
                continue
              logging.error(
                  "Cluster check alive failed, %s is down, "
                  "aborting collectives: %s", peer, e)
              context.context().abort_collective_ops(
                  errors.UNAVAILABLE,
                  "cluster check alive failed, {} is down".format(peer))
              return
            except Exception as e:  # pylint: disable=broad-except
              logging.error("Unexpected exception in check alive: %s", e)
              context.context().abort_collective_ops(
                  errors.INTERNAL,
                  "unexpected exception in check alive: %s" % e)
              return
      time.sleep(self._check_health_interval)

  def _start_check_health_thread(self):
    # Use a dummy all-reduce as a barrier to wait for all workers to be up,
    # otherwise the check health may fail immediately.

    # Use array_ops.identity to create the dummy tensor so that we have a new
    # Tensor. If we use a constant, it may be cached on a /job:localhost device,
    # which will cause some code that relies on tensor.device to error.
    #
    # TODO(b/151232436): change to an explicit barrier if we have it.
    dummy_value = array_ops.identity([])
    logging.info("Waiting for the cluster, timeout = %s",
                 self._check_health_initial_timeout or "inf")
    try:
      self._host_cross_device_ops.reduce(
          reduce_util.ReduceOp.SUM,
          dummy_value,
          dummy_value,
          options=collective_util.Options(
              timeout_seconds=self._check_health_initial_timeout,
              implementation=collective_util.CommunicationImplementation.RING))
      if context.is_async():
        context.async_wait()
    except errors.DeadlineExceededError:
      raise RuntimeError(
          "Timeout waiting for the cluster, timeout is %d seconds" %
          self._check_health_initial_timeout)
    logging.info("Cluster is ready.")
    self._check_health_thread_should_stop = threading.Event()
    # Start the thread as daemon to avoid it blocking the program from exiting.
    # We try our best to shut down the thread, but __del__ is not guaranteed to
    # be called when the program exits.
    self._check_health_thread = threading.Thread(
        target=self._check_health,
        daemon=True)
    self._check_health_thread.start()

  def _stop_check_health_thread(self):
    if getattr(self, "_check_health_thread", None):
      logging.info("stopping check health thread")
      self._check_health_thread_should_stop.set()
      self._check_health_thread.join()
      self._check_health_thread = None
      logging.info("check health thread stopped")

  def _warn_nccl_no_gpu(self):
    if ((self._communication_options.implementation ==
         collective_util.CommunicationImplementation.NCCL) and
        self._local_device_type != "GPU"):
      logging.warning("Enabled NCCL communication but no GPUs detected/"
                      "specified.")

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return self._num_workers > 1

  @property
  def experimental_between_graph(self):
    return True

  @property
  def experimental_should_init(self):
    return True

  @property
  def should_checkpoint(self):
    return self._is_chief

  @property
  def should_save_summary(self):
    return self._is_chief

  @property
  def _num_replicas_in_sync(self):
    return len(self.worker_devices) * self._num_workers

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

  def _get_replica_id_in_sync_group(self, replica_id):
    return self._id_in_cluster * len(self.worker_devices) + replica_id

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return (replica_id_in_sync_group -
            self._id_in_cluster * len(self.worker_devices))

  def __deepcopy__(self, memo):
    # We check the check health thread instead of whether we are in eager mode
    # to limit the backward incompatibility.
    if hasattr(self, "_check_health_thread"):
      raise ValueError(
          "MultiWorkerMirroredStrategy cannot be deep copied in eager mode. "
          "If you're using Estimator and see this error message, call "
          "tf.compat.v1.disable_eager_execution() at the beginning of your "
          "program")
    # Otherwise, do a regular deepcopy.
    cls = self.__class__
    result = cls.__new__(cls)
    memo[id(self)] = result
    for k, v in self.__dict__.items():
      setattr(result, k, copy.deepcopy(v, memo))
    return result
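

# The function below is a minimal, illustrative usage sketch, not part of this
# module's API and never called from library code. It assumes TF_CONFIG is set
# on each worker as described in the class docstring, and shows how a user
# program might pass explicit `CommunicationOptions` to the strategy and then
# combine per-replica results with `strategy.reduce`. The name
# `_example_usage_sketch` is purely illustrative.
def _example_usage_sketch():  # pragma: no cover
  import tensorflow as tf  # pylint: disable=g-import-not-at-top

  # Hint the collective implementation; AUTO, RING, and NCCL are the possible
  # values (NCCL only makes sense when GPUs are available).
  options = tf.distribute.experimental.CommunicationOptions(
      implementation=tf.distribute.experimental.CommunicationImplementation
      .RING)
  strategy = tf.distribute.MultiWorkerMirroredStrategy(
      communication_options=options)

  def replica_fn(x):
    # Runs once per replica; scales the input by the global replica count.
    return x * strategy.num_replicas_in_sync

  per_replica = strategy.run(replica_fn, args=(tf.constant(1.0),))
  # Sum the per-replica results across all replicas on all workers.
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)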