# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class MirroredStrategy implementing tf.distribute.Strategy."""

import copy

from tensorflow.python import tf2
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import input_util
from tensorflow.python.distribute import mirrored_run
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute import values_util
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

# TODO(josh11b): Replace asserts in this file with if ...: raise ...


def _is_device_list_single_worker(devices):
  """Checks whether the devices list is for single or multi-worker.

  Args:
    devices: a list of device strings or tf.config.LogicalDevice objects, for
      either local or remote devices.

  Returns:
    a boolean indicating whether the devices belong to a single worker (True)
    or to multiple workers (False).

  Raises:
    ValueError: if device strings are not consistent.
  """
  specs = []
  for d in devices:
    name = d.name if isinstance(d, context.LogicalDevice) else d
    specs.append(tf_device.DeviceSpec.from_string(name))
  num_workers = len({(d.job, d.task, d.replica) for d in specs})
  all_local = all(d.job in (None, "localhost") for d in specs)
  any_local = any(d.job in (None, "localhost") for d in specs)

  if any_local and not all_local:
    raise ValueError("Local devices should have only 'localhost' in the job "
                     "field of the device string, e.g. 'job:localhost' in "
                     "/job:localhost/replica:0/task:0/device:CPU:0. "
                     "The devices list cannot mix localhost devices with "
                     "devices from other job types such as worker or ps.")

  if num_workers == 1 and not all_local:
    if any(d.task is None for d in specs):
      raise ValueError("Remote device strings must have the task specified, "
                       "e.g. 'task:0' in "
                       "/job:worker/replica:0/task:0/device:CPU:0")

  return num_workers == 1


def _cluster_spec_to_device_list(cluster_spec, num_gpus_per_worker):
  """Returns a device list given a cluster spec."""
  cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
  devices = []
  for task_type in ("chief", "worker"):
    for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
      if num_gpus_per_worker == 0:
        devices.append("/job:%s/task:%d/device:CPU:0" % (task_type, task_id))
      else:
        devices.extend([
            "/job:%s/task:%d/device:GPU:%i" % (task_type, task_id, gpu_id)
            for gpu_id in range(num_gpus_per_worker)
        ])
  return devices
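
# Illustrative example (not executed): `_cluster_spec_to_device_list` maps a
# cluster spec such as
#   {"chief": ["host0:2222"], "worker": ["host1:2222"]}
# with num_gpus_per_worker=2 to
#   ["/job:chief/task:0/device:GPU:0", "/job:chief/task:0/device:GPU:1",
#    "/job:worker/task:0/device:GPU:0", "/job:worker/task:0/device:GPU:1"].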


def _group_device_list(devices):
  """Groups the devices list by task_type and task_id.

  Args:
    devices: a list of device strings for remote devices.

  Returns:
    a dict mapping task_type to a list of lists of device strings, indexed by
    task_id in ascending order.
  """
  assert not _is_device_list_single_worker(devices)
  device_dict = {}

  for d in devices:
    d_spec = tf_device.DeviceSpec.from_string(d)

    # Create an entry for the task_type.
    if d_spec.job not in device_dict:
      device_dict[d_spec.job] = []

    # Fill the device list for task_type until it covers the task_id.
    while len(device_dict[d_spec.job]) <= d_spec.task:
      device_dict[d_spec.job].append([])

    device_dict[d_spec.job][d_spec.task].append(d)

  return device_dict
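
# Illustrative example (not executed): `_group_device_list` maps
#   ["/job:worker/task:1/device:GPU:0", "/job:worker/task:0/device:GPU:0"]
# to
#   {"worker": [["/job:worker/task:0/device:GPU:0"],
#               ["/job:worker/task:1/device:GPU:0"]]}.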


def _is_gpu_device(device):
  return tf_device.DeviceSpec.from_string(device).device_type == "GPU"


def _infer_num_gpus_per_worker(devices):
  """Infers the number of GPUs on each worker.

  Currently, to make multi-worker cross device ops work, all workers need to
  have the same number of GPUs.

  Args:
    devices: a list of device strings; can be either local or remote devices.

  Returns:
    number of GPUs per worker.

  Raises:
    ValueError: if workers have different numbers of GPUs, or if GPU indices
      are not consecutive and starting from 0.
  """
  if _is_device_list_single_worker(devices):
    return sum(1 for d in devices if _is_gpu_device(d))
  else:
    device_dict = _group_device_list(devices)
    num_gpus = None
    for _, devices_in_task in device_dict.items():
      for device_in_task in devices_in_task:
        if num_gpus is None:
          num_gpus = sum(1 for d in device_in_task if _is_gpu_device(d))

        # Verify other workers have the same number of GPUs.
        elif num_gpus != sum(1 for d in device_in_task if _is_gpu_device(d)):
          raise ValueError("All workers should have the same number of GPUs.")

        for d in device_in_task:
          d_spec = tf_device.DeviceSpec.from_string(d)
          if (d_spec.device_type == "GPU" and
              d_spec.device_index >= num_gpus):
            raise ValueError("GPU `device_index` on a worker should be "
                             "consecutive and start from 0.")
    return num_gpus


def all_local_devices(num_gpus=None):
  devices = config.list_logical_devices("GPU")
  if num_gpus is not None:
    devices = devices[:num_gpus]
  return devices or config.list_logical_devices("CPU")


def all_devices():
  devices = []
  tfconfig = TFConfigClusterResolver()
  if tfconfig.cluster_spec().as_dict():
    devices = _cluster_spec_to_device_list(tfconfig.cluster_spec(),
                                           context.num_gpus())
  return devices if devices else all_local_devices()


@tf_export("distribute.MirroredStrategy", v1=[])  # pylint: disable=g-classes-have-attributes
class MirroredStrategy(distribute_lib.Strategy):
  """Synchronous training across multiple replicas on one machine.

  This strategy is typically used for training on one
  machine with multiple GPUs. For TPUs, use
  `tf.distribute.TPUStrategy`. To use `MirroredStrategy` with multiple workers,
  please refer to `tf.distribute.experimental.MultiWorkerMirroredStrategy`.

  For example, a variable created under a `MirroredStrategy` is a
  `MirroredVariable`. If no devices are specified in the constructor argument
  of the strategy, then it will use all the available GPUs. If no GPUs are
  found, it will use the available CPUs. Note that TensorFlow treats all CPUs
  on a machine as a single device, and uses threads internally for parallelism.

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> with strategy.scope():
  ...   x = tf.Variable(1.)
  >>> x
  MirroredVariable:{
    0: <tf.Variable ... shape=() dtype=float32, numpy=1.0>,
    1: <tf.Variable ... shape=() dtype=float32, numpy=1.0>
  }

  While using distribution strategies, all the variable creation should be done
  within the strategy's scope. This will replicate the variables across all the
  replicas and keep them in sync using an all-reduce algorithm.

  Variables created inside a `MirroredStrategy` scope are still
  `MirroredVariables` even when the creating function is wrapped with a
  `tf.function`.

  >>> x = []
  >>> @tf.function  # Wrap the function with tf.function.
  ... def create_variable():
  ...   if not x:
  ...     x.append(tf.Variable(1.))
  ...   return x[0]
  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> with strategy.scope():
  ...   _ = create_variable()
  ...   print(x[0])
  MirroredVariable:{
    0: <tf.Variable ... shape=() dtype=float32, numpy=1.0>,
    1: <tf.Variable ... shape=() dtype=float32, numpy=1.0>
  }

  `experimental_distribute_dataset` can be used to distribute the dataset
  across the replicas when writing your own training loop. If you are using the
  `.fit` and `.compile` methods available in `tf.keras`, then `tf.keras` will
  handle the distribution for you.

  For example:

  ```python
  my_strategy = tf.distribute.MirroredStrategy()
  with my_strategy.scope():
    @tf.function
    def distribute_train_epoch(dataset):
      def replica_fn(input):
        # process input and return result
        return result

      total_result = 0
      for x in dataset:
        per_replica_result = my_strategy.run(replica_fn, args=(x,))
        total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
                                           per_replica_result, axis=None)
      return total_result

    dist_dataset = my_strategy.experimental_distribute_dataset(dataset)
    for _ in range(EPOCHS):
      train_result = distribute_train_epoch(dist_dataset)
  ```
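
  By default this strategy uses `NcclAllReduce()` for cross-device
  communication (see the `cross_device_ops` argument below). If NCCL is not
  available in your environment, a different `tf.distribute.CrossDeviceOps`
  implementation can be passed in instead; the snippet below is an illustrative
  sketch using the built-in `tf.distribute.HierarchicalCopyAllReduce`:

  ```python
  strategy = tf.distribute.MirroredStrategy(
      cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
  ```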

  Args:
    devices: a list of device strings such as `['/gpu:0', '/gpu:1']`. If
      `None`, all available GPUs are used. If no GPUs are found, CPU is used.
    cross_device_ops: optional, a descendant of `CrossDeviceOps`. If this is
      not set, `NcclAllReduce()` will be used by default. One would customize
      this if NCCL isn't available or if a special implementation that exploits
      the particular hardware is available.
  """

  # Only set this in tests.
  _collective_key_base = 0

  def __init__(self, devices=None, cross_device_ops=None):
    extended = MirroredExtended(
        self, devices=devices, cross_device_ops=cross_device_ops)
    super(MirroredStrategy, self).__init__(extended)
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "MirroredStrategy")


@tf_export(v1=["distribute.MirroredStrategy"])
class MirroredStrategyV1(distribute_lib.StrategyV1):  # pylint: disable=g-missing-docstring

  __doc__ = MirroredStrategy.__doc__

  # Only set this in tests.
  _collective_key_base = 0

  def __init__(self, devices=None, cross_device_ops=None):
    extended = MirroredExtended(
        self, devices=devices, cross_device_ops=cross_device_ops)
    super(MirroredStrategyV1, self).__init__(extended)
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "MirroredStrategy")


# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
class MirroredExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of MirroredStrategy."""

  # If this is set to True, use NCCL collective ops instead of NCCL cross
  # device ops.
  _prefer_collective_ops = False

  def __init__(self, container_strategy, devices=None, cross_device_ops=None):
    super(MirroredExtended, self).__init__(container_strategy)
    if context.executing_eagerly():
      if devices and not _is_device_list_single_worker(devices):
        raise RuntimeError("In-graph multi-worker training with "
                           "`MirroredStrategy` is not supported in eager mode.")
      else:
        if TFConfigClusterResolver().cluster_spec().as_dict():
          # If you are executing in eager mode, only the single-machine code
          # path is supported.
          logging.info("Initializing local devices since in-graph multi-worker "
                       "training with `MirroredStrategy` is not supported in "
                       "eager mode. TF_CONFIG will be ignored when "
                       "initializing `MirroredStrategy`.")
        devices = devices or all_local_devices()
    else:
      devices = devices or all_devices()

    assert devices, ("Got an empty `devices` list and unable to recognize "
                     "any local devices.")
    self._cross_device_ops = cross_device_ops
    self._collective_ops_in_use = False
    self._collective_key_base = container_strategy._collective_key_base
    self._communication_options = collective_util.Options(
        implementation=collective_util.CommunicationImplementation.NCCL)
    self._initialize_strategy(devices)

    # TODO(b/128995245): Enable last partial batch support in graph mode.
    if ops.executing_eagerly_outside_functions():
      self.experimental_enable_get_next_as_optional = True

    # Flag to turn on VariablePolicy.
    self._use_var_policy = False

  def _use_merge_call(self):
    # We currently only disable merge_call when XLA is used to compile the `fn`
    # passed to `strategy.run` and all devices are GPU.
    return not control_flow_util.GraphOrParentsInXlaContext(
        ops.get_default_graph()) or not all(
            [_is_gpu_device(d) for d in self._devices])

  def _initialize_strategy(self, devices):
    # The _initialize_strategy method is intended to be used by distribute
    # coordinator as well.
    assert devices, "Must specify at least one device."
    devices = tuple(device_util.resolve(d) for d in devices)
    assert len(set(devices)) == len(devices), (
        "No duplicates allowed in `devices` argument: %s" % (devices,))
    if _is_device_list_single_worker(devices):
      self._initialize_single_worker(devices)
      self._collective_ops = self._make_collective_ops(devices)
      if self._prefer_collective_ops and (
          isinstance(self._cross_device_ops, cross_device_ops_lib.NcclAllReduce)
          or isinstance(self._inferred_cross_device_ops,
                        cross_device_ops_lib.NcclAllReduce)):
        self._collective_ops_in_use = True
        self._inferred_cross_device_ops = None
      logging.info("Using MirroredStrategy with devices %r", devices)
    else:
      self._initialize_multi_worker(devices)

  def _make_collective_ops(self, devices):
    self._collective_keys = cross_device_utils.CollectiveKeys(
        group_key_start=1 + self._collective_key_base)
    return cross_device_ops_lib.CollectiveAllReduce(
        devices=self._devices,
        group_size=len(self._devices),
        options=self._communication_options,
        collective_keys=self._collective_keys)

  def _initialize_single_worker(self, devices):
    """Initializes the object for single-worker training."""
    self._devices = tuple(device_util.canonicalize(d) for d in devices)
    self._input_workers_devices = (
        (device_util.canonicalize("/device:CPU:0", devices[0]), devices),)

    self._inferred_cross_device_ops = None if self._cross_device_ops else (
        cross_device_ops_lib.select_cross_device_ops(devices))
    self._host_input_device = numpy_dataset.SingleDevice(
        self._input_workers_devices[0][0])
    self._is_multi_worker_training = False
    device_spec = tf_device.DeviceSpec.from_string(
        self._input_workers_devices[0][0])
    # Ensures when we enter strategy.scope() we use the correct default device.
    if device_spec.job is not None and device_spec.job != "localhost":
      self._default_device = "/job:%s/replica:%d/task:%d" % (
          device_spec.job, device_spec.replica, device_spec.task)

  def _initialize_multi_worker(self, devices):
    """Initializes the object for multi-worker training."""
    device_dict = _group_device_list(devices)
    workers = []
    worker_devices = []
    for job in ("chief", "worker"):
      for task in range(len(device_dict.get(job, []))):
        worker = "/job:%s/task:%d" % (job, task)
        workers.append(worker)
        worker_devices.append((worker, device_dict[job][task]))

    # Setting `_default_device` will add a device scope in the
    # distribution.scope. We set the default device to the first worker. When
    # users specify a device under distribution.scope by
    #   with tf.device("/cpu:0"):
    #     ...
    # their ops will end up on the CPU device of the first worker, e.g.
    # "/job:worker/task:0/device:CPU:0". Note this is not used in replica mode.
    self._default_device = workers[0]
    self._host_input_device = numpy_dataset.SingleDevice(workers[0])

    self._devices = tuple(devices)
    self._input_workers_devices = worker_devices
    self._is_multi_worker_training = True

    if len(workers) > 1:
      # Grandfather usage in the legacy tests if they're configured properly.
      if (not isinstance(self._cross_device_ops,
                         cross_device_ops_lib.ReductionToOneDevice) or
          self._cross_device_ops._num_between_graph_workers > 1):  # pylint: disable=protected-access
        raise ValueError(
            "In-graph multi-worker training with `MirroredStrategy` is not "
            "supported.")
      self._inferred_cross_device_ops = self._cross_device_ops
    else:
      # TODO(yuefengz): make `select_cross_device_ops` work with device strings
      # containing job names.
      self._inferred_cross_device_ops = cross_device_ops_lib.NcclAllReduce()

    logging.info("Using MirroredStrategy with remote devices %r", devices)

  def _input_workers_with_options(self, options=None):
    if not options:
      return input_lib.InputWorkers(self._input_workers_devices)
    if (options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      if options.experimental_place_dataset_on_device:
        self._input_workers_devices = (
            tuple(
                (device_util.canonicalize(d, d), (d,)) for d in self._devices))
      else:
        self._input_workers_devices = (
            tuple((device_util.canonicalize("/device:CPU:0", d), (d,))
                  for d in self._devices))
      return input_lib.InputWorkers(self._input_workers_devices)
    else:
      if not options.experimental_fetch_to_device:
        return input_lib.InputWorkers([
            (host_device, (host_device,) * len(compute_devices))
            for host_device, compute_devices in self._input_workers_devices
        ])
      else:
        return input_lib.InputWorkers(self._input_workers_devices)

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _get_variable_creator_initial_value(self,
                                          replica_id,
                                          device,
                                          primary_var,
                                          **kwargs):
    """Return the initial value for variables on a replica."""
    if replica_id == 0:
      return kwargs["initial_value"]
    else:
      assert primary_var is not None
      assert device is not None
      assert kwargs is not None

      def initial_value_fn():
        if context.executing_eagerly() or ops.inside_function():
          init_value = primary_var.value()
          return array_ops.identity(init_value)
        else:
          with ops.device(device):
            init_value = primary_var.initial_value
            return array_ops.identity(init_value)

      return initial_value_fn

  def _create_variable(self, next_creator, **kwargs):
    """Create a mirrored variable. See `DistributionStrategy.scope`."""
    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      devices = self._devices
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      with ops.device(colocate_with.device):
        return next_creator(**kwargs)
    else:
      devices = colocate_with._devices  # pylint: disable=protected-access

    def _real_mirrored_creator(**kwargs):  # pylint: disable=g-missing-docstring
      value_list = []
      for i, d in enumerate(devices):
        with ops.device(d):
          kwargs["initial_value"] = self._get_variable_creator_initial_value(
              replica_id=i,
              device=d,
              primary_var=value_list[0] if value_list else None,
              **kwargs)
          if i > 0:
            # Give replicas meaningful distinct names:
            var0name = value_list[0].name.split(":")[0]
            # We append a / to variable names created on replicas with id > 0
            # to ensure that we ignore the name scope and instead use the given
            # name as the absolute name of the variable.
            kwargs["name"] = "%s/replica_%d/" % (var0name, i)
          with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
            # Don't record operations (e.g. other variable reads) during
            # variable creation.
            with tape.stop_recording():
              v = next_creator(**kwargs)
          assert not isinstance(v, values.DistributedVariable)
          value_list.append(v)
      return value_list

    return distribute_utils.create_mirrored_variable(
        self._container_strategy(), _real_mirrored_creator,
        distribute_utils.VARIABLE_CLASS_MAPPING,
        distribute_utils.VARIABLE_POLICY_MAPPING, **kwargs)

  def _validate_colocate_with_variable(self, colocate_with_variable):
    distribute_utils.validate_colocate_distributed_variable(
        colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    return input_lib_v1.DatasetIterator(
        dataset,
        self._input_workers,
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    input_contexts = []
    num_workers = self._input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))
    return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers,
                                              input_contexts,
                                              self._container_strategy())

  def _experimental_distribute_dataset(self, dataset, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`distribute_datasets_from_function`."
      )
    return input_util.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync,
        options=options)

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, self._host_input_device, session)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    input_workers = self._input_workers_with_options(options)
    input_contexts = []
    num_workers = input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))

    return input_util.get_distributed_datasets_from_function(
        dataset_fn, input_workers, input_contexts, self._container_strategy(),
        options)

  def _experimental_distribute_values_from_function(self, value_fn):
    per_replica_values = []
    for replica_id in range(self._num_replicas_in_sync):
      per_replica_values.append(value_fn(
          distribute_lib.ValueContext(replica_id,
                                      self._num_replicas_in_sync)))
    return distribute_utils.regroup(per_replica_values, always_wrap=True)
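
  # Illustrative sketch (not executed) of how user code reaches the method
  # above through the public API, assuming a two-replica strategy:
  #   def value_fn(ctx):
  #     return tf.constant(ctx.replica_id_in_sync_group)
  #   per_replica = strategy.experimental_distribute_values_from_function(
  #       value_fn)
  #   # `per_replica` holds 0 on replica 0 and 1 on replica 1.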

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = input_lib.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
      fn_result = fn(ctx, iterator.get_next())
      for (name, output) in ctx.last_step_outputs.items():
        # Convert all outputs to tensors, potentially from `DistributedValues`.
        ctx.last_step_outputs[name] = self._local_results(output)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop. This is useful in cases where we might need to exit
    # these contexts and get back to the outer context to do some things, e.g.
    # to create an op which should be evaluated only once at the end of the
    # loop on the host. One such usage is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    cond = lambda i, *args: i < iterations
    i = constant_op.constant(0)
    loop_result = control_flow_ops.while_loop(
        cond, body, [i] + initial_loop_values, name="",
        parallel_iterations=1, back_prop=False, swap_memory=False,
        return_same_structure=True)
    del self._outer_control_flow_context

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
      output = last_step_tensor_outputs_dict[name]
      # For outputs that have already been reduced, wrap them in a Mirrored
      # container, else in a PerReplica container.
      if reduce_op is None:
        last_step_tensor_outputs_dict[name] = distribute_utils.regroup(output)
      else:
        assert len(output) == 1
        last_step_tensor_outputs_dict[name] = output[0]

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx

  def _broadcast_to(self, tensor, destinations):
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    # TODO(josh11b): In eager mode, use one thread per device, or async mode.
    if not destinations:
      # TODO(josh11b): Use current logical device instead of 0 here.
      destinations = self._devices
    return self._get_cross_device_ops(tensor).broadcast(tensor, destinations)

  def _call_for_each_replica(self, fn, args, kwargs):
    return mirrored_run.call_for_each_replica(
        self._container_strategy(), fn, args, kwargs)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    del task_type, task_id

    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

    if cluster_spec:
      # TODO(yuefengz): remove the following code once cluster_resolver is
      # added.
      num_gpus_per_worker = _infer_num_gpus_per_worker(self._devices)
      multi_worker_devices = _cluster_spec_to_device_list(
          cluster_spec, num_gpus_per_worker)
      self._initialize_multi_worker(multi_worker_devices)

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    updated_config.isolate_session_state = True
    return updated_config

  def _get_cross_device_ops(self, value):
    if not self._use_merge_call():
      return self._collective_ops

    if self._collective_ops_in_use:
      if isinstance(value, values.DistributedValues):
        value_int32 = True in {
            dtypes.as_dtype(v.dtype) == dtypes.int32 for v in value.values
        }
      else:
        value_int32 = dtypes.as_dtype(value.dtype) == dtypes.int32
      if value_int32:
        return cross_device_ops_lib.ReductionToOneDevice()
      else:
        return self._collective_ops

    return self._cross_device_ops or self._inferred_cross_device_ops

  def _gather_to_implementation(self, value, destinations, axis, options):
    if not isinstance(value, values.DistributedValues):
      # ReductionToOneDevice._gather accepts DistributedValues only.
      return value
    return self._get_cross_device_ops(value)._gather(  # pylint: disable=protected-access
        value,
        destinations=destinations,
        axis=axis,
        options=self._communication_options.merge(options))

  def _reduce_to(self, reduce_op, value, destinations, options):
    if (distribute_utils.is_mirrored(value) and
        reduce_op == reduce_util.ReduceOp.MEAN):
      return value
    assert not distribute_utils.is_mirrored(value)
    def get_values(value):
      if not isinstance(value, values.DistributedValues):
        # This function handles reducing values that are not PerReplica or
        # Mirrored values. For example, the same value could be present on all
        # replicas in which case `value` would be a single value or value could
        # be 0.
        return cross_device_ops_lib.reduce_non_distributed_value(
            reduce_op, value, destinations, self._num_replicas_in_sync)
      if self._use_merge_call() and self._collective_ops_in_use and ((
          not cross_device_ops_lib._devices_match(value, destinations) or  # pylint: disable=protected-access
          any("cpu" in d.lower()
              for d in cross_device_ops_lib.get_devices_from(destinations)))):
        return cross_device_ops_lib.ReductionToOneDevice().reduce(
            reduce_op, value, destinations)
      return self._get_cross_device_ops(value).reduce(
          reduce_op,
          value,
          destinations=destinations,
          options=self._communication_options.merge(options))

    return nest.map_structure(get_values, value)

  def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
    cross_device_ops = None
    for value, _ in value_destination_pairs:
      if cross_device_ops is None:
        cross_device_ops = self._get_cross_device_ops(value)
      elif cross_device_ops is not self._get_cross_device_ops(value):
        raise ValueError("Inputs to batch_reduce_to must be either all on "
                         "the host or all on the compute devices.")
    return cross_device_ops.batch_reduce(
        reduce_op,
        value_destination_pairs,
        options=self._communication_options.merge(options))

  def _update(self, var, fn, args, kwargs, group):
    # TODO(josh11b): In eager mode, use one thread per device.
    assert isinstance(var, values.DistributedVariable)
    updates = []
    for i, v in enumerate(var.values):
      name = "update_%d" % i
      with ops.device(v.device), \
           distribute_lib.UpdateContext(i), \
           ops.name_scope(name):
        # If args and kwargs are not mirrored, the value is returned as is.
        updates.append(
            fn(v, *distribute_utils.select_replica(i, args),
               **distribute_utils.select_replica(i, kwargs)))
    return distribute_utils.update_regroup(self, updates, group)

  def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
    """Implements `StrategyExtendedV2._replica_ctx_all_reduce`."""
    # This implementation avoids using `merge_call` and just launches collective
    # ops in one replica.
    if options is None:
      options = collective_util.Options()

    if context.executing_eagerly() or (
        not tf2.enabled()) or self._use_merge_call():
      # In eager mode, falls back to the default implementation that uses
      # `merge_call`. Replica functions are running sequentially in eager mode,
      # and due to the blocking nature of collective ops, execution will hang if
      # collective ops are to be launched sequentially.
      return super()._replica_ctx_all_reduce(reduce_op, value, options)

    replica_context = distribution_strategy_context.get_replica_context()
    assert replica_context, (
        "`StrategyExtended._replica_ctx_all_reduce` must be called in a "
        "replica context")
    return self._get_cross_device_ops(value)._all_reduce(  # pylint: disable=protected-access
        reduce_op,
        value,
        replica_context._replica_id,  # pylint: disable=protected-access
        options)

  def _replica_ctx_update(self, var, fn, args, kwargs, group):
    if self._use_merge_call():
      return super()._replica_ctx_update(var, fn, args, kwargs, group)

    replica_context = distribution_strategy_context.get_replica_context()
    assert replica_context
    replica_id = values_util.get_current_replica_id_as_int()
    name = "update_%d" % replica_id

    if isinstance(var, values.DistributedVariable):
      var = var._get_replica(replica_id)  # pylint: disable=protected-access

    with ops.device(var.device), ops.name_scope(name):
      result = fn(var, *args, **kwargs)
      return result

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    assert isinstance(colocate_with, tuple)
    # TODO(josh11b): In eager mode, use one thread per device.
    updates = []
    for i, d in enumerate(colocate_with):
      name = "update_%d" % i
      with ops.device(d), distribute_lib.UpdateContext(i), ops.name_scope(name):
        updates.append(
            fn(*distribute_utils.select_replica(i, args),
               **distribute_utils.select_replica(i, kwargs)))
    return distribute_utils.update_regroup(self, updates, group)

  def read_var(self, replica_local_var):
    """Read the aggregate value of a replica-local variable."""
    # pylint: disable=protected-access
    if distribute_utils.is_sync_on_read(replica_local_var):
      return replica_local_var._get_cross_replica()
    assert distribute_utils.is_mirrored(replica_local_var)
    return array_ops.identity(replica_local_var._get())
    # pylint: enable=protected-access

  def value_container(self, val):
    return distribute_utils.value_container(val)

  @property
  def _num_replicas_in_sync(self):
    return len(self._devices)

  @property
  def worker_devices(self):
    return self._devices

  @property
  def worker_devices_by_replica(self):
    return [[d] for d in self._devices]

  @property
  def parameter_devices(self):
    return self.worker_devices

  @property
  def experimental_between_graph(self):
    return False

  @property
  def experimental_should_init(self):
    return True

  @property
  def should_checkpoint(self):
    return True

  @property
  def should_save_summary(self):
    return True

  def non_slot_devices(self, var_list):
    del var_list
    # TODO(josh11b): Should this be the last logical device instead?
    return self._devices

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return False

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group

  def _get_replica_id_in_sync_group(self, replica_id):
    return replica_id