1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""TPU Strategy.""" 16 17import atexit 18import collections 19import contextlib 20import copy 21import functools 22import weakref 23 24from absl import logging 25import numpy as np 26 27from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding 28from tensorflow.python.autograph.core import ag_ctx as autograph_ctx 29from tensorflow.python.autograph.impl import api as autograph 30from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib 31from tensorflow.python.distribute import device_util 32from tensorflow.python.distribute import distribute_lib 33from tensorflow.python.distribute import distribute_utils 34from tensorflow.python.distribute import input_lib 35from tensorflow.python.distribute import input_util 36from tensorflow.python.distribute import numpy_dataset 37from tensorflow.python.distribute import reduce_util 38from tensorflow.python.distribute import tpu_replicated_variable 39from tensorflow.python.distribute import tpu_util 40from tensorflow.python.distribute import tpu_values 41from tensorflow.python.distribute import values 42from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver 43from tensorflow.python.distribute.v1 import input_lib as input_lib_v1 44from tensorflow.python.eager import context 45from tensorflow.python.eager import def_function 46from tensorflow.python.eager import function 47from tensorflow.python.framework import constant_op 48from tensorflow.python.framework import device_spec 49from tensorflow.python.framework import dtypes 50from tensorflow.python.framework import indexed_slices 51from tensorflow.python.framework import ops 52from tensorflow.python.framework import sparse_tensor 53from tensorflow.python.framework import tensor_shape 54from tensorflow.python.framework import tensor_util 55from tensorflow.python.ops import array_ops 56from tensorflow.python.ops import control_flow_ops 57from tensorflow.python.ops import math_ops 58from tensorflow.python.ops import resource_variable_ops 59from tensorflow.python.ops import variables as variables_lib 60from tensorflow.python.ops.ragged import ragged_tensor 61from tensorflow.python.tpu import device_assignment as device_assignment_lib # pylint: disable=unused-import 62from tensorflow.python.tpu import tpu 63from tensorflow.python.tpu import tpu_hardware_feature 64from tensorflow.python.tpu import tpu_strategy_util 65from tensorflow.python.tpu import training_loop 66from tensorflow.python.tpu.ops import tpu_ops 67from tensorflow.python.util import deprecation 68from tensorflow.python.util import nest 69from tensorflow.python.util import tf_inspect 70from tensorflow.python.util.tf_export import tf_export 71 72 73_XLA_OP_BY_OP_INPUTS_LIMIT = 200 74 75 76@contextlib.contextmanager 77def maybe_init_scope(): 78 if ops.executing_eagerly_outside_functions(): 79 yield 80 else: 81 with ops.init_scope(): 82 yield 83 84 85def validate_run_function(fn): 86 """Validate the function passed into strategy.run.""" 87 88 # We allow three types of functions/objects passed into TPUStrategy 89 # run in eager mode: 90 # 1. a user annotated tf.function 91 # 2. a ConcreteFunction, this is mostly what you get from loading a saved 92 # model. 93 # 3. a callable object and the `__call__` method itself is a tf.function. 94 # 95 # Otherwise we return an error, because we don't support eagerly running 96 # run in TPUStrategy. 97 98 if context.executing_eagerly() \ 99 and not isinstance(fn, def_function.Function) \ 100 and not isinstance(fn, function.ConcreteFunction) \ 101 and not (callable(fn) and isinstance(fn.__call__, def_function.Function)): 102 raise NotImplementedError( 103 "TPUStrategy.run(fn, ...) does not support pure eager " 104 "execution. please make sure the function passed into " 105 "`strategy.run` is a `tf.function` or " 106 "`strategy.run` is called inside a `tf.function` if " 107 "eager behavior is enabled.") 108 109 110def _maybe_partial_apply_variables(fn, args, kwargs): 111 """Inspects arguments to partially apply any DistributedVariable. 112 113 This avoids an automatic cast of the current variable value to tensor. 114 115 Note that a variable may be captured implicitly with Python scope instead of 116 passing it to run(), but supporting run() keeps behavior consistent 117 with MirroredStrategy. 118 119 Since positional arguments must be applied from left to right, this function 120 does some tricky function inspection to move variable positional arguments 121 into kwargs. As a result of this, we can't support passing Variables as *args, 122 nor as args to functions which combine both explicit positional arguments and 123 *args. 124 125 Args: 126 fn: The function to run, as passed to run(). 127 args: Positional arguments to fn, as passed to run(). 128 kwargs: Keyword arguments to fn, as passed to run(). 129 130 Returns: 131 A tuple of the function (possibly wrapped), args, kwargs (both 132 possibly filtered, with members of args possibly moved to kwargs). 133 If no variables are found, this function is a noop. 134 135 Raises: 136 ValueError: If the function signature makes unsupported use of *args, or if 137 too many arguments are passed. 138 """ 139 140 def is_distributed_var(x): 141 flat = nest.flatten(x) 142 return flat and isinstance(flat[0], values.DistributedVariable) 143 144 # We will split kwargs into two dicts, one of which will be applied now. 145 var_kwargs = {} 146 nonvar_kwargs = {} 147 148 if kwargs: 149 var_kwargs = {k: v for k, v in kwargs.items() if is_distributed_var(v)} 150 if var_kwargs: 151 nonvar_kwargs = { 152 k: v for k, v in kwargs.items() if not is_distributed_var(v) 153 } 154 155 # Dump the argument names of `fn` to a list. This will include both positional 156 # and keyword arguments, but since positional arguments come first we can 157 # look up names of positional arguments by index. 158 positional_args = [] 159 index_of_star_args = None 160 for i, p in enumerate(tf_inspect.signature(fn).parameters.values()): 161 # Class methods define "self" as first argument, but we don't pass "self". 162 # Note that this is a heuristic, as a method can name its first argument 163 # something else, and a function can define a first argument "self" as well. 164 # In both of these cases, using a Variable will fail with an unfortunate 165 # error about the number of arguments. 166 # inspect.is_method() seems not to work here, possibly due to the use of 167 # tf.function(). 168 if i == 0 and p.name == "self": 169 continue 170 171 if p.kind == tf_inspect.Parameter.POSITIONAL_OR_KEYWORD: 172 positional_args.append(p.name) 173 174 elif p.kind == tf_inspect.Parameter.VAR_POSITIONAL: 175 # We'll raise an error later if a variable is passed to *args, since we 176 # can neither pass it by name nor partially apply it. This case only 177 # happens once at most. 178 index_of_star_args = i 179 180 elif p.kind == tf_inspect.Parameter.POSITIONAL_ONLY: 181 # This is a rare Python feature, indicating a / in the arg list. 182 if var_kwargs or any(is_distributed_var(a) for a in args): 183 raise ValueError( 184 "Mixing Variables and positional-only parameters not supported by " 185 f"TPUStrategy. Received {len(var_kwargs)} DistributedVariables in " 186 f"**kwargs and {sum(is_distributed_var(a) for a in args)} in *args," 187 " expected zero for both." 188 ) 189 return fn, args, kwargs 190 191 star_args = [] 192 have_seen_var_arg = False 193 194 for i, a in enumerate(args): 195 if is_distributed_var(a): 196 if index_of_star_args is not None and i >= index_of_star_args: 197 raise ValueError( 198 "TPUStrategy.run() cannot handle Variables passed to *args. " 199 "Either name the function argument, or capture the Variable " 200 "implicitly.") 201 if len(positional_args) <= i: 202 raise ValueError( 203 "Too many positional arguments passed to call to TPUStrategy.run()." 204 ) 205 var_kwargs[positional_args[i]] = a 206 have_seen_var_arg = True 207 else: 208 if index_of_star_args is not None and i >= index_of_star_args: 209 if have_seen_var_arg: 210 raise ValueError( 211 "TPUStrategy.run() cannot handle both Variables and a mix of " 212 "positional args and *args. Either remove the *args, or capture " 213 "the Variable implicitly.") 214 else: 215 star_args.append(a) 216 continue 217 218 if len(positional_args) <= i: 219 raise ValueError( 220 "Too many positional arguments passed to call to TPUStrategy.run()." 221 ) 222 nonvar_kwargs[positional_args[i]] = a 223 224 if var_kwargs: 225 return functools.partial(fn, **var_kwargs), star_args, nonvar_kwargs 226 return fn, args, kwargs 227 228 229@tf_export("distribute.TPUStrategy", v1=[]) 230class TPUStrategyV2(distribute_lib.Strategy): 231 """Synchronous training on TPUs and TPU Pods. 232 233 To construct a TPUStrategy object, you need to run the 234 initialization code as below: 235 236 >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 237 >>> tf.config.experimental_connect_to_cluster(resolver) 238 >>> tf.tpu.experimental.initialize_tpu_system(resolver) 239 >>> strategy = tf.distribute.TPUStrategy(resolver) 240 241 While using distribution strategies, the variables created within the 242 strategy's scope will be replicated across all the replicas and can be kept in 243 sync using all-reduce algorithms. 244 245 To run TF2 programs on TPUs, you can either use `.compile` and 246 `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized 247 training loop by calling `strategy.run` directly. Note that 248 TPUStrategy doesn't support pure eager execution, so please make sure the 249 function passed into `strategy.run` is a `tf.function` or 250 `strategy.run` is called inside a `tf.function` if eager 251 behavior is enabled. See more details in https://www.tensorflow.org/guide/tpu. 252 253 `distribute_datasets_from_function` and 254 `experimental_distribute_dataset` APIs can be used to distribute the dataset 255 across the TPU workers when writing your own training loop. If you are using 256 `fit` and `compile` methods available in `tf.keras.Model`, then Keras will 257 handle the distribution for you. 258 259 An example of writing customized training loop on TPUs: 260 261 >>> with strategy.scope(): 262 ... model = tf.keras.Sequential([ 263 ... tf.keras.layers.Dense(2, input_shape=(5,)), 264 ... ]) 265 ... optimizer = tf.keras.optimizers.SGD(learning_rate=0.1) 266 267 >>> def dataset_fn(ctx): 268 ... x = np.random.random((2, 5)).astype(np.float32) 269 ... y = np.random.randint(2, size=(2, 1)) 270 ... dataset = tf.data.Dataset.from_tensor_slices((x, y)) 271 ... return dataset.repeat().batch(1, drop_remainder=True) 272 >>> dist_dataset = strategy.distribute_datasets_from_function( 273 ... dataset_fn) 274 >>> iterator = iter(dist_dataset) 275 276 >>> @tf.function() 277 ... def train_step(iterator): 278 ... 279 ... def step_fn(inputs): 280 ... features, labels = inputs 281 ... with tf.GradientTape() as tape: 282 ... logits = model(features, training=True) 283 ... loss = tf.keras.losses.sparse_categorical_crossentropy( 284 ... labels, logits) 285 ... 286 ... grads = tape.gradient(loss, model.trainable_variables) 287 ... optimizer.apply_gradients(zip(grads, model.trainable_variables)) 288 ... 289 ... strategy.run(step_fn, args=(next(iterator),)) 290 291 >>> train_step(iterator) 292 293 For the advanced use cases like model parallelism, you can set 294 `experimental_device_assignment` argument when creating TPUStrategy to specify 295 number of replicas and number of logical devices. Below is an example to 296 initialize TPU system with 2 logical devices and 1 replica. 297 298 >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 299 >>> tf.config.experimental_connect_to_cluster(resolver) 300 >>> topology = tf.tpu.experimental.initialize_tpu_system(resolver) 301 >>> device_assignment = tf.tpu.experimental.DeviceAssignment.build( 302 ... topology, 303 ... computation_shape=[1, 1, 1, 2], 304 ... num_replicas=1) 305 >>> strategy = tf.distribute.TPUStrategy( 306 ... resolver, experimental_device_assignment=device_assignment) 307 308 Then you can run a `tf.add` operation only on logical device 0. 309 310 >>> @tf.function() 311 ... def step_fn(inputs): 312 ... features, _ = inputs 313 ... output = tf.add(features, features) 314 ... 315 ... # Add operation will be executed on logical device 0. 316 ... output = strategy.experimental_assign_to_logical_device(output, 0) 317 ... return output 318 >>> dist_dataset = strategy.distribute_datasets_from_function( 319 ... dataset_fn) 320 >>> iterator = iter(dist_dataset) 321 >>> strategy.run(step_fn, args=(next(iterator),)) 322 323 `experimental_spmd_xla_partitioning` enables the experimental XLA SPMD feature 324 for model parallelism. This flag can reduce the compilation time and HBM 325 requirements. When running in this mode, every input tensor must either be 326 partitioned (via `strategy.experimental_split_to_logical_devices`) or fully 327 replicated (via `strategy.experimental_replicate_to_logical_devices`) to all 328 logical devices. And calling `strategy.experimental_assign_to_logical_device` 329 will result in a ValueError in this mode. 330 """ 331 332 def __init__(self, 333 tpu_cluster_resolver=None, 334 experimental_device_assignment=None, 335 experimental_spmd_xla_partitioning=False): 336 """Synchronous training in TPU donuts or Pods. 337 338 Args: 339 tpu_cluster_resolver: A 340 `tf.distribute.cluster_resolver.TPUClusterResolver` instance, which 341 provides information about the TPU cluster. If None, it will assume 342 running on a local TPU worker. 343 experimental_device_assignment: Optional 344 `tf.tpu.experimental.DeviceAssignment` to specify the placement of 345 replicas on the TPU cluster. 346 experimental_spmd_xla_partitioning: If True, enable the SPMD (Single 347 Program Multiple Data) mode in XLA compiler. This flag only affects the 348 performance of XLA compilation and the HBM requirement of the compiled 349 TPU program. Ceveat: if this flag is True, calling 350 `tf.distribute.TPUStrategy.experimental_assign_to_logical_device` will 351 result in a ValueError. 352 """ 353 super(TPUStrategyV2, self).__init__( 354 TPUExtended( 355 self, 356 tpu_cluster_resolver, 357 device_assignment=experimental_device_assignment, 358 use_spmd_for_xla_partitioning=experimental_spmd_xla_partitioning)) 359 distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy") 360 distribute_lib.distribution_strategy_replica_gauge.get_cell( 361 "num_workers").set(self.extended.num_hosts) 362 distribute_lib.distribution_strategy_replica_gauge.get_cell( 363 "num_replicas_per_worker").set(self.extended.num_replicas_per_host) 364 # Packed variable is used to reduce the overhead of function execution. 365 # For a DistributedVariable, only one variable handle is captured into a 366 # function graph. It's only supported in eager mode. 367 # Packed variable is currently not supported when SPMD is enabled. 368 # TODO(b/202047549): enable Packed variable in SPMD mode. 369 self._enable_packed_variable_in_eager_mode = ( 370 not experimental_spmd_xla_partitioning) 371 372 def run(self, fn, args=(), kwargs=None, options=None): 373 """Run the computation defined by `fn` on each TPU replica. 374 375 Executes ops specified by `fn` on each replica. If `args` or `kwargs` have 376 `tf.distribute.DistributedValues`, such as those produced by a 377 `tf.distribute.DistributedDataset` from 378 `tf.distribute.Strategy.experimental_distribute_dataset` or 379 `tf.distribute.Strategy.distribute_datasets_from_function`, 380 when `fn` is executed on a particular replica, it will be executed with the 381 component of `tf.distribute.DistributedValues` that correspond to that 382 replica. 383 384 `fn` may call `tf.distribute.get_replica_context()` to access members such 385 as `all_reduce`. 386 387 All arguments in `args` or `kwargs` should either be nest of tensors or 388 `tf.distribute.DistributedValues` containing tensors or composite tensors. 389 390 Example usage: 391 392 >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 393 >>> tf.config.experimental_connect_to_cluster(resolver) 394 >>> tf.tpu.experimental.initialize_tpu_system(resolver) 395 >>> strategy = tf.distribute.TPUStrategy(resolver) 396 >>> @tf.function 397 ... def run(): 398 ... def value_fn(value_context): 399 ... return value_context.num_replicas_in_sync 400 ... distributed_values = ( 401 ... strategy.experimental_distribute_values_from_function(value_fn)) 402 ... def replica_fn(input): 403 ... return input * 2 404 ... return strategy.run(replica_fn, args=(distributed_values,)) 405 >>> result = run() 406 407 Args: 408 fn: The function to run. The output must be a `tf.nest` of `Tensor`s. 409 args: (Optional) Positional arguments to `fn`. 410 kwargs: (Optional) Keyword arguments to `fn`. 411 options: (Optional) An instance of `tf.distribute.RunOptions` specifying 412 the options to run `fn`. 413 414 Returns: 415 Merged return value of `fn` across replicas. The structure of the return 416 value is the same as the return value from `fn`. Each element in the 417 structure can either be `tf.distribute.DistributedValues`, `Tensor` 418 objects, or `Tensor`s (for example, if running on a single replica). 419 """ 420 validate_run_function(fn) 421 422 fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs) 423 424 # Note: the target function is converted to graph even when in Eager mode, 425 # so autograph is on by default here. 426 fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx()) 427 options = options or distribute_lib.RunOptions() 428 return self.extended.tpu_run(fn, args, kwargs, options) 429 430 @property 431 def cluster_resolver(self): 432 """Returns the cluster resolver associated with this strategy. 433 434 `tf.distribute.TPUStrategy` provides the associated 435 `tf.distribute.cluster_resolver.ClusterResolver`. If the user provides one 436 in `__init__`, that instance is returned; if the user does not, a default 437 `tf.distribute.cluster_resolver.TPUClusterResolver` is provided. 438 """ 439 return self.extended._tpu_cluster_resolver # pylint: disable=protected-access 440 441 def experimental_assign_to_logical_device(self, tensor, logical_device_id): 442 """Adds annotation that `tensor` will be assigned to a logical device. 443 444 This adds an annotation to `tensor` specifying that operations on 445 `tensor` will be invoked on logical core device id `logical_device_id`. 446 When model parallelism is used, the default behavior is that all ops 447 are placed on zero-th logical device. 448 449 ```python 450 451 # Initializing TPU system with 2 logical devices and 4 replicas. 452 resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 453 tf.config.experimental_connect_to_cluster(resolver) 454 topology = tf.tpu.experimental.initialize_tpu_system(resolver) 455 device_assignment = tf.tpu.experimental.DeviceAssignment.build( 456 topology, 457 computation_shape=[1, 1, 1, 2], 458 num_replicas=4) 459 strategy = tf.distribute.TPUStrategy( 460 resolver, experimental_device_assignment=device_assignment) 461 iterator = iter(inputs) 462 463 @tf.function() 464 def step_fn(inputs): 465 output = tf.add(inputs, inputs) 466 467 # Add operation will be executed on logical device 0. 468 output = strategy.experimental_assign_to_logical_device(output, 0) 469 return output 470 471 strategy.run(step_fn, args=(next(iterator),)) 472 ``` 473 474 Args: 475 tensor: Input tensor to annotate. 476 logical_device_id: Id of the logical core to which the tensor will be 477 assigned. 478 479 Raises: 480 ValueError: The logical device id presented is not consistent with total 481 number of partitions specified by the device assignment or the TPUStrategy 482 is constructed with `experimental_spmd_xla_partitioning=True`. 483 484 Returns: 485 Annotated tensor with identical value as `tensor`. 486 """ 487 if self.extended._use_spmd_for_xla_partitioning: # pylint: disable=protected-access 488 raise ValueError( 489 "Cannot assign a tensor to a logical device in SPMD mode. To disable " 490 "SPMD, Please construct the TPUStrategy with " 491 "`experimental_spmd_xla_partitioning=False`") 492 493 num_logical_devices_per_replica = self.extended._tpu_devices.shape[1] # pylint: disable=protected-access 494 if (logical_device_id < 0 or 495 logical_device_id >= num_logical_devices_per_replica): 496 raise ValueError("`logical_core_id` to assign must be lower then total " 497 "number of logical devices per replica. Received " 498 "logical device id {} but there are only total of {} " 499 "logical devices in replica.".format( 500 logical_device_id, num_logical_devices_per_replica)) 501 return xla_sharding.assign_device( 502 tensor, logical_device_id, use_sharding_op=True) 503 504 def experimental_split_to_logical_devices(self, tensor, partition_dimensions): 505 """Adds annotation that `tensor` will be split across logical devices. 506 507 This adds an annotation to tensor `tensor` specifying that operations on 508 `tensor` will be split among multiple logical devices. Tensor `tensor` will 509 be split across dimensions specified by `partition_dimensions`. 510 The dimensions of `tensor` must be divisible by corresponding value in 511 `partition_dimensions`. 512 513 For example, for system with 8 logical devices, if `tensor` is an image 514 tensor with shape (batch_size, width, height, channel) and 515 `partition_dimensions` is [1, 2, 4, 1], then `tensor` will be split 516 2 in width dimension and 4 way in height dimension and the split 517 tensor values will be fed into 8 logical devices. 518 519 ```python 520 # Initializing TPU system with 8 logical devices and 1 replica. 521 resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 522 tf.config.experimental_connect_to_cluster(resolver) 523 topology = tf.tpu.experimental.initialize_tpu_system(resolver) 524 device_assignment = tf.tpu.experimental.DeviceAssignment.build( 525 topology, 526 computation_shape=[1, 2, 2, 2], 527 num_replicas=1) 528 # Construct the TPUStrategy. Since we are going to split the image across 529 # logical devices, here we set `experimental_spmd_xla_partitioning=True` 530 # so that the partitioning can be compiled in SPMD mode, which usually 531 # results in faster compilation and smaller HBM requirement if the size of 532 # input and activation tensors are much bigger than that of the model 533 # parameters. Note that this flag is suggested but not a hard requirement 534 # for `experimental_split_to_logical_devices`. 535 strategy = tf.distribute.TPUStrategy( 536 resolver, experimental_device_assignment=device_assignment, 537 experimental_spmd_xla_partitioning=True) 538 539 iterator = iter(inputs) 540 541 @tf.function() 542 def step_fn(inputs): 543 inputs = strategy.experimental_split_to_logical_devices( 544 inputs, [1, 2, 4, 1]) 545 546 # model() function will be executed on 8 logical devices with `inputs` 547 # split 2 * 4 ways. 548 output = model(inputs) 549 return output 550 551 strategy.run(step_fn, args=(next(iterator),)) 552 ``` 553 Args: 554 tensor: Input tensor to annotate. 555 partition_dimensions: An unnested list of integers with the size equal to 556 rank of `tensor` specifying how `tensor` will be partitioned. The 557 product of all elements in `partition_dimensions` must be equal to the 558 total number of logical devices per replica. 559 560 Raises: 561 ValueError: 1) If the size of partition_dimensions does not equal to rank 562 of `tensor` or 2) if product of elements of `partition_dimensions` does 563 not match the number of logical devices per replica defined by the 564 implementing DistributionStrategy's device specification or 565 3) if a known size of `tensor` is not divisible by corresponding 566 value in `partition_dimensions`. 567 568 Returns: 569 Annotated tensor with identical value as `tensor`. 570 """ 571 num_logical_devices_per_replica = self.extended._tpu_devices.shape[1] # pylint: disable=protected-access 572 num_partition_splits = np.prod(partition_dimensions) 573 input_shape = tensor.shape 574 tensor_rank = len(input_shape) 575 576 if tensor_rank != len(partition_dimensions): 577 raise ValueError("Length of `partition_dimensions` must equal to the " 578 "rank of `tensor.shape` ({}). Received " 579 "len(partition_dimensions)={}.".format( 580 tensor_rank, len(partition_dimensions))) 581 582 for dim_index, dim_size in enumerate(input_shape): 583 if dim_size is None: 584 continue 585 586 split_size = partition_dimensions[dim_index] 587 if dim_size % split_size != 0: 588 raise ValueError("Tensor shape at `partition_dimensions[{}]` must be " 589 "divisible by corresponding value specified " 590 "by `partition_dimensions` ({}). Received: {}.".format( 591 dim_index, split_size, dim_size)) 592 593 if num_partition_splits != num_logical_devices_per_replica: 594 raise ValueError( 595 "The product of `partition_dimensions` should be the same as the " 596 "number of logical devices (={}). Received `partition_dimensions`={}," 597 "and their product is {}.".format(num_logical_devices_per_replica, 598 partition_dimensions, 599 num_partition_splits)) 600 601 tile_assignment = np.arange(num_partition_splits).reshape( 602 partition_dimensions) 603 return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True) 604 605 def experimental_replicate_to_logical_devices(self, tensor): 606 """Adds annotation that `tensor` will be replicated to all logical devices. 607 608 This adds an annotation to tensor `tensor` specifying that operations on 609 `tensor` will be invoked on all logical devices. 610 611 ```python 612 # Initializing TPU system with 2 logical devices and 4 replicas. 613 resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 614 tf.config.experimental_connect_to_cluster(resolver) 615 topology = tf.tpu.experimental.initialize_tpu_system(resolver) 616 device_assignment = tf.tpu.experimental.DeviceAssignment.build( 617 topology, 618 computation_shape=[1, 1, 1, 2], 619 num_replicas=4) 620 strategy = tf.distribute.TPUStrategy( 621 resolver, experimental_device_assignment=device_assignment) 622 623 iterator = iter(inputs) 624 625 @tf.function() 626 def step_fn(inputs): 627 images, labels = inputs 628 images = strategy.experimental_split_to_logical_devices( 629 inputs, [1, 2, 4, 1]) 630 631 # model() function will be executed on 8 logical devices with `inputs` 632 # split 2 * 4 ways. 633 output = model(inputs) 634 635 # For loss calculation, all logical devices share the same logits 636 # and labels. 637 labels = strategy.experimental_replicate_to_logical_devices(labels) 638 output = strategy.experimental_replicate_to_logical_devices(output) 639 loss = loss_fn(labels, output) 640 641 return loss 642 643 strategy.run(step_fn, args=(next(iterator),)) 644 ``` 645 Args: 646 tensor: Input tensor to annotate. 647 648 Returns: 649 Annotated tensor with identical value as `tensor`. 650 """ 651 return xla_sharding.replicate(tensor, use_sharding_op=True) 652 653 654@tf_export("distribute.experimental.TPUStrategy", v1=[]) 655@deprecation.deprecated_endpoints("distribute.experimental.TPUStrategy") 656class TPUStrategy(distribute_lib.Strategy): 657 """Synchronous training on TPUs and TPU Pods. 658 659 To construct a TPUStrategy object, you need to run the 660 initialization code as below: 661 662 >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 663 >>> tf.config.experimental_connect_to_cluster(resolver) 664 >>> tf.tpu.experimental.initialize_tpu_system(resolver) 665 >>> strategy = tf.distribute.experimental.TPUStrategy(resolver) 666 667 While using distribution strategies, the variables created within the 668 strategy's scope will be replicated across all the replicas and can be kept in 669 sync using all-reduce algorithms. 670 671 To run TF2 programs on TPUs, you can either use `.compile` and 672 `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized 673 training loop by calling `strategy.run` directly. Note that 674 TPUStrategy doesn't support pure eager execution, so please make sure the 675 function passed into `strategy.run` is a `tf.function` or 676 `strategy.run` is called inside a `tf.function` if eager 677 behavior is enabled. 678 """ 679 680 def __init__(self, 681 tpu_cluster_resolver=None, 682 device_assignment=None): 683 """Synchronous training in TPU donuts or Pods. 684 685 Args: 686 tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver, 687 which provides information about the TPU cluster. 688 device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to 689 specify the placement of replicas on the TPU cluster. 690 """ 691 logging.warning( 692 "`tf.distribute.experimental.TPUStrategy` is deprecated, please use " 693 " the non experimental symbol `tf.distribute.TPUStrategy` instead.") 694 695 super(TPUStrategy, self).__init__( 696 TPUExtended( 697 self, tpu_cluster_resolver, device_assignment=device_assignment)) 698 distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy") 699 distribute_lib.distribution_strategy_replica_gauge.get_cell( 700 "num_workers").set(self.extended.num_hosts) 701 distribute_lib.distribution_strategy_replica_gauge.get_cell( 702 "num_replicas_per_worker").set(self.extended.num_replicas_per_host) 703 # Packed variable is used to reduce the overhead of function execution. 704 # For a DistributedVariable, only one variable handle is captured into a 705 # function graph. It's only supported in eager mode. 706 self._enable_packed_variable_in_eager_mode = True 707 708 # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this 709 # can use the default implementation. 710 # This implementation runs a single step. It does not use infeed or outfeed. 711 def run(self, fn, args=(), kwargs=None, options=None): 712 """See base class.""" 713 validate_run_function(fn) 714 715 fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs) 716 717 # Note: the target function is converted to graph even when in Eager mode, 718 # so autograph is on by default here. 719 fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx()) 720 options = options or distribute_lib.RunOptions() 721 return self.extended.tpu_run(fn, args, kwargs, options) 722 723 @property 724 def cluster_resolver(self): 725 """Returns the cluster resolver associated with this strategy. 726 727 `tf.distribute.experimental.TPUStrategy` provides the 728 associated `tf.distribute.cluster_resolver.ClusterResolver`. If the user 729 provides one in `__init__`, that instance is returned; if the user does 730 not, a default 731 `tf.distribute.cluster_resolver.TPUClusterResolver` is provided. 732 """ 733 return self.extended._tpu_cluster_resolver # pylint: disable=protected-access 734 735 736@tf_export(v1=["distribute.experimental.TPUStrategy"]) 737class TPUStrategyV1(distribute_lib.StrategyV1): 738 """TPU distribution strategy implementation.""" 739 740 def __init__(self, 741 tpu_cluster_resolver=None, 742 steps_per_run=None, 743 device_assignment=None): 744 """Initializes the TPUStrategy object. 745 746 Args: 747 tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver, 748 which provides information about the TPU cluster. 749 steps_per_run: Number of steps to run on device before returning to the 750 host. Note that this can have side-effects on performance, hooks, 751 metrics, summaries etc. 752 This parameter is only used when Distribution Strategy is used with 753 estimator or keras. 754 device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to 755 specify the placement of replicas on the TPU cluster. Currently only 756 supports the usecase of using a single core within a TPU cluster. 757 """ 758 super(TPUStrategyV1, self).__init__(TPUExtended( 759 self, tpu_cluster_resolver, steps_per_run, device_assignment)) 760 distribute_lib.distribution_strategy_gauge.get_cell("V1").set("TPUStrategy") 761 distribute_lib.distribution_strategy_replica_gauge.get_cell( 762 "num_workers").set(self.extended.num_hosts) 763 distribute_lib.distribution_strategy_replica_gauge.get_cell( 764 "num_replicas_per_worker").set(self.extended.num_replicas_per_host) 765 # Packed variable is used to reduce the overhead of function execution. 766 # For a DistributedVariable, only one variable handle is captured into a 767 # function graph. It's only supported in eager mode. 768 self._enable_packed_variable_in_eager_mode = True 769 770 @property 771 def steps_per_run(self): 772 """DEPRECATED: use .extended.steps_per_run instead.""" 773 return self._extended.steps_per_run 774 775 # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this 776 # can use the default implementation. 777 # This implementation runs a single step. It does not use infeed or outfeed. 778 def run(self, fn, args=(), kwargs=None, options=None): 779 """Run `fn` on each replica, with the given arguments. 780 781 Executes ops specified by `fn` on each replica. If `args` or `kwargs` have 782 "per-replica" values, such as those produced by a "distributed `Dataset`", 783 when `fn` is executed on a particular replica, it will be executed with the 784 component of those "per-replica" values that correspond to that replica. 785 786 `fn` may call `tf.distribute.get_replica_context()` to access members such 787 as `all_reduce`. 788 789 All arguments in `args` or `kwargs` should either be nest of tensors or 790 per-replica objects containing tensors or composite tensors. 791 792 Users can pass strategy specific options to `options` argument. An example 793 to enable bucketizing dynamic shapes in `TPUStrategy.run` 794 is: 795 796 >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 797 >>> tf.config.experimental_connect_to_cluster(resolver) 798 >>> tf.tpu.experimental.initialize_tpu_system(resolver) 799 >>> strategy = tf.distribute.experimental.TPUStrategy(resolver) 800 801 >>> options = tf.distribute.RunOptions( 802 ... experimental_bucketizing_dynamic_shape=True) 803 804 >>> dataset = tf.data.Dataset.range( 805 ... strategy.num_replicas_in_sync, output_type=dtypes.float32).batch( 806 ... strategy.num_replicas_in_sync, drop_remainder=True) 807 >>> input_iterator = iter(strategy.experimental_distribute_dataset(dataset)) 808 809 >>> @tf.function() 810 ... def step_fn(inputs): 811 ... output = tf.reduce_sum(inputs) 812 ... return output 813 814 >>> strategy.run(step_fn, args=(next(input_iterator),), options=options) 815 816 Args: 817 fn: The function to run. The output must be a `tf.nest` of `Tensor`s. 818 args: (Optional) Positional arguments to `fn`. 819 kwargs: (Optional) Keyword arguments to `fn`. 820 options: (Optional) An instance of `tf.distribute.RunOptions` specifying 821 the options to run `fn`. 822 823 Returns: 824 Merged return value of `fn` across replicas. The structure of the return 825 value is the same as the return value from `fn`. Each element in the 826 structure can either be "per-replica" `Tensor` objects or `Tensor`s 827 (for example, if running on a single replica). 828 """ 829 validate_run_function(fn) 830 831 fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs) 832 833 fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx()) 834 options = options or distribute_lib.RunOptions() 835 return self.extended.tpu_run(fn, args, kwargs, options) 836 837 838# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1. 839class TPUExtended(distribute_lib.StrategyExtendedV1): 840 """Implementation of TPUStrategy.""" 841 842 def __init__(self, 843 container_strategy, 844 tpu_cluster_resolver=None, 845 steps_per_run=None, 846 device_assignment=None, 847 use_spmd_for_xla_partitioning=False): 848 super(TPUExtended, self).__init__(container_strategy) 849 850 if tpu_cluster_resolver is None: 851 tpu_cluster_resolver = TPUClusterResolver("") 852 853 if steps_per_run is None: 854 # TODO(frankchn): Warn when we are being used by DS/Keras and this is 855 # not specified. 856 steps_per_run = 1 857 858 # `self._tpu_function_cache` is a dict of `tf.function`s, thus if a 859 # `tf.function` is passed into `strategy.run` in eager mode, the 860 # `tf.function` won't get retraced. 861 self._tpu_function_cache = weakref.WeakKeyDictionary() 862 863 self._tpu_cluster_resolver = tpu_cluster_resolver 864 self._tpu_metadata = self._tpu_cluster_resolver.get_tpu_system_metadata() 865 self._device_assignment = device_assignment 866 867 tpu_devices_flat = [ 868 d.name for d in self._tpu_metadata.devices if "device:TPU:" in d.name] 869 870 # `self._tpu_devices` is a two-dimensional NumPy array of strings. It is 871 # indexed using `[replica_id][logical_device_id]`. 872 if device_assignment is None: 873 self._tpu_devices = np.array( 874 [[d] for d in tpu_devices_flat], dtype=object) 875 else: 876 job_name = device_spec.DeviceSpecV2.from_string(tpu_devices_flat[0]).job 877 878 tpu_devices = [] 879 for replica_id in range(device_assignment.num_replicas): 880 replica_devices = [] 881 882 for logical_core in range(device_assignment.num_cores_per_replica): 883 replica_devices.append( 884 device_util.canonicalize( 885 device_assignment.tpu_device( 886 replica=replica_id, 887 logical_core=logical_core, 888 job=job_name))) 889 890 tpu_devices.append(replica_devices) 891 self._tpu_devices = np.array(tpu_devices, dtype=object) 892 893 self._host_device = device_util.get_host_for_device(self._tpu_devices[0][0]) 894 895 # Preload the data onto the TPUs. Currently we always preload onto logical 896 # device 0 for each replica. 897 # TODO(cjfj): Create `InputWorkers` lazily, allowing users to place the 898 # input onto a different logical device? 899 self._device_input_worker_devices = collections.OrderedDict() 900 self._host_input_worker_devices = collections.OrderedDict() 901 for tpu_device in self._tpu_devices[:, 0]: 902 host_device = device_util.get_host_for_device(tpu_device) 903 self._device_input_worker_devices.setdefault(host_device, []) 904 self._device_input_worker_devices[host_device].append(tpu_device) 905 self._host_input_worker_devices.setdefault(host_device, []) 906 self._host_input_worker_devices[host_device].append(host_device) 907 908 # TODO(sourabhbajaj): Remove this once performance of running one step 909 # at a time is comparable to multiple steps. 910 self.steps_per_run = steps_per_run 911 self._require_static_shapes = True 912 913 self.experimental_enable_get_next_as_optional = True 914 915 self._logical_device_stack = [0] 916 917 if context.executing_eagerly(): 918 # In async remote eager, we want to sync the executors before exiting the 919 # program. 920 atexit.register(context.async_wait) 921 922 # Flag to turn on VariablePolicy. Var policy is deprecated because there is 923 # another effort unifying DistributedVariables (see values_v2.py). SPMD XLA 924 # partitioning is not implemented for var policies. 925 # TODO(b/202048882): remove var policy from TPUStrategy. 926 self._use_var_policy = not use_spmd_for_xla_partitioning 927 928 # Flag to enable XLA SPMD partitioning. 929 self._use_spmd_for_xla_partitioning = use_spmd_for_xla_partitioning 930 931 def _validate_colocate_with_variable(self, colocate_with_variable): 932 distribute_utils. validate_colocate(colocate_with_variable, self) 933 934 def _make_dataset_iterator(self, dataset): 935 """Make iterators for each of the TPU hosts.""" 936 input_workers = input_lib.InputWorkers( 937 tuple(self._device_input_worker_devices.items())) 938 return input_lib_v1.DatasetIterator( 939 dataset, 940 input_workers, 941 self._container_strategy(), 942 num_replicas_in_sync=self._num_replicas_in_sync) 943 944 def _make_input_fn_iterator( 945 self, 946 input_fn, 947 replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): 948 input_contexts = [] 949 input_workers = input_lib.InputWorkers( 950 tuple(self._device_input_worker_devices.items())) 951 num_workers = input_workers.num_workers 952 for i in range(num_workers): 953 input_contexts.append( 954 distribute_lib.InputContext( 955 num_input_pipelines=num_workers, 956 input_pipeline_id=i, 957 num_replicas_in_sync=self._num_replicas_in_sync)) 958 return input_lib_v1.InputFunctionIterator(input_fn, input_workers, 959 input_contexts, 960 self._container_strategy()) 961 962 def _experimental_make_numpy_dataset(self, numpy_input, session): 963 return numpy_dataset.one_host_numpy_dataset( 964 numpy_input, numpy_dataset.SingleDevice(self._host_device), 965 session) 966 967 def _get_input_workers(self, options): 968 if not options or options.experimental_fetch_to_device: 969 return input_lib.InputWorkers( 970 tuple(self._device_input_worker_devices.items())) 971 else: 972 return input_lib.InputWorkers( 973 tuple(self._host_input_worker_devices.items())) 974 975 def _check_spec(self, element_spec): 976 if isinstance(element_spec, values.PerReplicaSpec): 977 element_spec = element_spec._component_specs # pylint: disable=protected-access 978 specs = nest.flatten_with_joined_string_paths(element_spec) 979 for path, spec in specs: 980 if isinstance(spec, (sparse_tensor.SparseTensorSpec, 981 ragged_tensor.RaggedTensorSpec)): 982 raise ValueError( 983 "Found tensor {} with spec {}. TPUStrategy does not support " 984 "distributed datasets with device prefetch when using sparse or " 985 "ragged tensors. If you intend to use sparse or ragged tensors, " 986 "please pass a tf.distribute.InputOptions object with " 987 "experimental_fetch_to_device set to False to your dataset " 988 "distribution function.".format(path, type(spec))) 989 990 def _experimental_distribute_dataset(self, dataset, options): 991 if (options and options.experimental_replication_mode == 992 distribute_lib.InputReplicationMode.PER_REPLICA): 993 raise NotImplementedError( 994 "InputReplicationMode.PER_REPLICA " 995 "is only supported in " 996 "`experimental_distribute_datasets_from_function`." 997 ) 998 if options is None or options.experimental_fetch_to_device: 999 self._check_spec(dataset.element_spec) 1000 1001 return input_util.get_distributed_dataset( 1002 dataset, 1003 self._get_input_workers(options), 1004 self._container_strategy(), 1005 num_replicas_in_sync=self._num_replicas_in_sync, 1006 options=options) 1007 1008 def _distribute_datasets_from_function(self, dataset_fn, options): 1009 if (options and options.experimental_replication_mode == 1010 distribute_lib.InputReplicationMode.PER_REPLICA): 1011 raise NotImplementedError( 1012 "InputReplicationMode.PER_REPLICA " 1013 "is only supported in " 1014 " `experimental_distribute_datasets_from_function` " 1015 "of tf.distribute.MirroredStrategy") 1016 input_workers = self._get_input_workers(options) 1017 input_contexts = [] 1018 num_workers = input_workers.num_workers 1019 for i in range(num_workers): 1020 input_contexts.append(distribute_lib.InputContext( 1021 num_input_pipelines=num_workers, 1022 input_pipeline_id=i, 1023 num_replicas_in_sync=self._num_replicas_in_sync)) 1024 1025 distributed_dataset = input_util.get_distributed_datasets_from_function( 1026 dataset_fn, 1027 input_workers, 1028 input_contexts, 1029 self._container_strategy(), 1030 options=options) 1031 1032 # We can only check after the dataset_fn is called. 1033 if options is None or options.experimental_fetch_to_device: 1034 self._check_spec(distributed_dataset.element_spec) 1035 return distributed_dataset 1036 1037 def _experimental_distribute_values_from_function(self, value_fn): 1038 per_replica_values = [] 1039 for replica_id in range(self._num_replicas_in_sync): 1040 per_replica_values.append( 1041 value_fn(distribute_lib.ValueContext(replica_id, 1042 self._num_replicas_in_sync))) 1043 return distribute_utils.regroup(per_replica_values, always_wrap=True) 1044 1045 # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. 1046 # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have 1047 # a mechanism to infer the outputs of `fn`. Pending b/110550782. 1048 def _experimental_run_steps_on_iterator( 1049 self, fn, multi_worker_iterator, iterations, initial_loop_values=None): 1050 # Wrap `fn` for repeat. 1051 if initial_loop_values is None: 1052 initial_loop_values = {} 1053 initial_loop_values = nest.flatten(initial_loop_values) 1054 ctx = input_lib.MultiStepContext() 1055 1056 def run_fn(inputs): 1057 """Single step on the TPU device.""" 1058 fn_result = fn(ctx, inputs) 1059 flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) 1060 if flat_last_step_outputs: 1061 with ops.control_dependencies([fn_result]): 1062 return [array_ops.identity(f) for f in flat_last_step_outputs] 1063 else: 1064 return fn_result 1065 1066 # We capture the control_flow_context at this point, before we run `fn` 1067 # inside a while_loop and TPU replicate context. This is useful in cases 1068 # where we might need to exit these contexts and get back to the outer 1069 # context to do some things, for e.g. create an op which should be 1070 # evaluated only once at the end of the loop on the host. One such usage 1071 # is in creating metrics' value op. 1072 self._outer_control_flow_context = ( 1073 ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access 1074 1075 def rewrite_fn(*args): 1076 """The rewritten step fn running on TPU.""" 1077 del args 1078 1079 per_replica_inputs = multi_worker_iterator.get_next() 1080 replicate_inputs = [] 1081 for replica_id in range(self._num_replicas_in_sync): 1082 select_replica = lambda x: distribute_utils.select_replica( # pylint: disable=g-long-lambda 1083 replica_id, x) # pylint: disable=cell-var-from-loop 1084 replicate_inputs.append((nest.map_structure( 1085 select_replica, per_replica_inputs),)) 1086 1087 replicate_outputs = tpu.replicate( 1088 run_fn, 1089 replicate_inputs, 1090 device_assignment=self._device_assignment, 1091 xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=self 1092 ._use_spmd_for_xla_partitioning)) 1093 # If run_fn has tensor outputs, tpu.replicate returns a list of list. We 1094 # will flatten it in this case. If run_fn has no tensor outputs, 1095 # tpu.replicate returns a list of no_ops, we will keep the output as it 1096 # is. 1097 if isinstance(replicate_outputs[0], list): 1098 replicate_outputs = nest.flatten(replicate_outputs) 1099 1100 return replicate_outputs 1101 1102 # TODO(sourabhbajaj): The input to while loop should be based on the 1103 # output type of the step_fn 1104 assert isinstance(initial_loop_values, list) 1105 initial_loop_values = initial_loop_values * self._num_replicas_in_sync 1106 1107 # Put the while loop op on TPU host 0. 1108 with ops.device(self._host_device): 1109 if self.steps_per_run == 1: 1110 replicate_outputs = rewrite_fn() 1111 else: 1112 replicate_outputs = training_loop.repeat(iterations, rewrite_fn, 1113 initial_loop_values) 1114 1115 del self._outer_control_flow_context 1116 ctx.run_op = control_flow_ops.group(replicate_outputs) 1117 1118 if isinstance(replicate_outputs, list): 1119 # Filter out any ops from the outputs, typically this would be the case 1120 # when there were no tensor outputs. 1121 last_step_tensor_outputs = [ 1122 x for x in replicate_outputs if not isinstance(x, ops.Operation) 1123 ] 1124 1125 # Outputs are currently of the structure (flattened) 1126 # [output0_device0, output1_device0, output2_device0, 1127 # output0_device1, output1_device1, output2_device1, 1128 # ...] 1129 # Convert this to the following structure instead: (grouped by output) 1130 # [[output0_device0, output0_device1], 1131 # [output1_device0, output1_device1], 1132 # [output2_device0, output2_device1]] 1133 output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync 1134 last_step_tensor_outputs = [ 1135 last_step_tensor_outputs[i::output_num] for i in range(output_num) 1136 ] 1137 else: 1138 # no tensors returned. 1139 last_step_tensor_outputs = [] 1140 1141 _set_last_step_outputs(ctx, last_step_tensor_outputs) 1142 return ctx 1143 1144 def _call_for_each_replica(self, fn, args, kwargs): 1145 # TODO(jhseu): Consider making it so call_for_each_replica implies that 1146 # we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly. 1147 with _TPUReplicaContext(self._container_strategy()): 1148 return fn(*args, **kwargs) 1149 1150 @contextlib.contextmanager 1151 def experimental_logical_device(self, logical_device_id): 1152 """Places variables and ops on the specified logical device.""" 1153 num_logical_devices_per_replica = self._tpu_devices.shape[1] 1154 if logical_device_id >= num_logical_devices_per_replica: 1155 raise ValueError( 1156 "`logical_device_id` not in range (was {}, but there are only {} " 1157 "logical devices per replica).".format( 1158 logical_device_id, num_logical_devices_per_replica)) 1159 1160 self._logical_device_stack.append(logical_device_id) 1161 try: 1162 if tpu_util.enclosing_tpu_context() is None: 1163 yield 1164 else: 1165 with ops.device(tpu.core(logical_device_id)): 1166 yield 1167 finally: 1168 self._logical_device_stack.pop() 1169 1170 def _experimental_initialize_system(self): 1171 """Experimental method added to be used by Estimator. 1172 1173 This is a private method only to be used by Estimator. Other frameworks 1174 should directly be calling `tf.tpu.experimental.initialize_tpu_system` 1175 """ 1176 tpu_strategy_util.initialize_tpu_system(self._tpu_cluster_resolver) 1177 1178 def _create_variable(self, next_creator, **kwargs): 1179 """Create a TPUMirroredVariable. See `DistributionStrategy.scope`.""" 1180 if kwargs.pop("skip_mirrored_creator", False): 1181 return next_creator(**kwargs) 1182 1183 colocate_with = kwargs.pop("colocate_with", None) 1184 if colocate_with is None: 1185 devices = self._tpu_devices[:, self._logical_device_stack[-1]] 1186 elif isinstance(colocate_with, numpy_dataset.SingleDevice): 1187 with ops.device(colocate_with.device): 1188 return next_creator(**kwargs) 1189 else: 1190 devices = colocate_with._devices # pylint: disable=protected-access 1191 1192 num_replicas, num_cores_per_replica = self._tpu_devices.shape 1193 1194 def _create_mirrored_tpu_variables(**kwargs): 1195 """Returns a list of `tf.Variable`s. 1196 1197 The list contains `number_replicas` `tf.Variable`s and can be used to 1198 initialize a `TPUMirroredVariable`. 1199 1200 Args: 1201 **kwargs: the keyword arguments for creating a variable 1202 """ 1203 initial_value = None 1204 value_list = [] 1205 for i, d in enumerate(devices): 1206 with ops.device(d): 1207 if i == 0: 1208 initial_value = kwargs["initial_value"] 1209 # Note: some v1 code expects variable initializer creation to happen 1210 # inside a init_scope. 1211 with maybe_init_scope(): 1212 initial_value = initial_value() if callable( 1213 initial_value) else initial_value 1214 1215 if i > 0: 1216 # Give replicas meaningful distinct names: 1217 var0name = value_list[0].name.split(":")[0] 1218 # We append a / to variable names created on replicas with id > 0 to 1219 # ensure that we ignore the name scope and instead use the given 1220 # name as the absolute name of the variable. 1221 kwargs["name"] = "%s/replica_%d/" % (var0name, i) 1222 kwargs["initial_value"] = initial_value 1223 1224 with context.device_policy(context.DEVICE_PLACEMENT_SILENT): 1225 v = next_creator(**kwargs) 1226 1227 assert not isinstance(v, tpu_values.TPUMirroredVariable) 1228 value_list.append(v) 1229 return value_list 1230 1231 def _create_mirrored_tpu_replicated_variables(**kwargs): 1232 """Returns a list of `TPUReplicatedVariable`s. 1233 1234 The list consists of `num_replicas` `TPUReplicatedVariable`s and can be 1235 used to initialize a `TPUMirroredVariable`. Each `TPUReplicatedVariable` 1236 contains a list of `tf.Variable`s which are replicated to 1237 `num_cores_per_replica` logical cores to enable XLA SPMD compilation. 1238 1239 Args: 1240 **kwargs: the keyword arguments for creating a variable 1241 """ 1242 initial_value = kwargs["initial_value"] 1243 # Note: some v1 code expects variable initializer creation to happen 1244 # inside a init_scope. 1245 with maybe_init_scope(): 1246 initial_value = initial_value() if callable( 1247 initial_value) else initial_value 1248 1249 mirrored_replicated_var_list = [] 1250 1251 for replica_id in range(num_replicas): 1252 replicated_var_list = [] 1253 for logic_core_id in range(num_cores_per_replica): 1254 with ops.device(self._tpu_devices[replica_id][logic_core_id]): 1255 kwargs["initial_value"] = initial_value 1256 v = next_creator(**kwargs) 1257 replicated_var_list.append(v) 1258 replica_name = "{}/r:{}".format(kwargs["name"], replica_id) 1259 tpu_replicated_var = tpu_replicated_variable.TPUReplicatedVariable( 1260 variables=replicated_var_list, name=replica_name) 1261 1262 mirrored_replicated_var_list.append(tpu_replicated_var) 1263 return mirrored_replicated_var_list 1264 1265 if self._use_spmd_for_xla_partitioning and num_cores_per_replica > 1: 1266 real_creator = _create_mirrored_tpu_replicated_variables 1267 else: 1268 real_creator = _create_mirrored_tpu_variables 1269 1270 return distribute_utils.create_mirrored_variable( 1271 self._container_strategy(), real_creator, 1272 distribute_utils.TPU_VARIABLE_CLASS_MAPPING, 1273 distribute_utils.TPU_VARIABLE_POLICY_MAPPING, **kwargs) 1274 1275 def _resource_creator_scope(self): 1276 1277 def lookup_creator(next_creator, *args, **kwargs): 1278 host_to_table = collections.OrderedDict() 1279 for host_device in self._device_input_worker_devices.keys(): 1280 with ops.device(host_device): 1281 host_to_table[host_device] = next_creator(*args, **kwargs) 1282 1283 return values.PerWorkerResource(self._container_strategy(), host_to_table) 1284 1285 # TODO(b/194362531): Define creator(s) for other resources. 1286 return ops.resource_creator_scope("StaticHashTable", lookup_creator) 1287 1288 def _gather_to_implementation(self, value, destinations, axis, options): 1289 if not isinstance(value, values.DistributedValues): 1290 return value 1291 1292 value_list = list(value.values) 1293 # pylint: disable=protected-access 1294 if isinstance( 1295 value, 1296 values.DistributedVariable) and value._packed_variable is not None: 1297 value_list = list( 1298 value._packed_variable.on_device(d) 1299 for d in value._packed_variable.devices) 1300 # pylint: enable=protected-access 1301 1302 # Currently XLA op by op mode has a limit for the number of inputs for a 1303 # single op, thus we break one `add_n` op into a group of `add_n` ops to 1304 # work around the constraint. 1305 if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT: 1306 output = array_ops.concat(value_list, axis=axis) 1307 else: 1308 output = array_ops.concat( 1309 value_list[:_XLA_OP_BY_OP_INPUTS_LIMIT], axis=axis) 1310 for i in range(_XLA_OP_BY_OP_INPUTS_LIMIT, len(value_list), 1311 _XLA_OP_BY_OP_INPUTS_LIMIT - 1): 1312 output = array_ops.concat( 1313 [output] + value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT - 1], 1314 axis=axis) 1315 1316 output = self._broadcast_output(destinations, output) 1317 return output 1318 1319 def _broadcast_output(self, destinations, output): 1320 devices = cross_device_ops_lib.get_devices_from(destinations) 1321 1322 if len(devices) == 1: 1323 # If necessary, copy to requested destination. 1324 dest_canonical = device_util.canonicalize(devices[0]) 1325 host_canonical = device_util.canonicalize(self._host_device) 1326 1327 if dest_canonical != host_canonical: 1328 with ops.device(dest_canonical): 1329 output = array_ops.identity(output) 1330 else: 1331 output = cross_device_ops_lib.simple_broadcast(output, destinations) 1332 1333 return output 1334 1335 def _reduce_to(self, reduce_op, value, destinations, options): 1336 if (isinstance(value, values.DistributedValues) or 1337 tensor_util.is_tf_type(value) 1338 ) and tpu_util.enclosing_tpu_context() is not None: 1339 if reduce_op == reduce_util.ReduceOp.MEAN: 1340 # TODO(jhseu): Revisit once we support model-parallelism. 1341 # scalar_mul maintains the type of value: tensor or IndexedSlices. 1342 value = math_ops.scalar_mul((1./self._num_replicas_in_sync), value) 1343 elif reduce_op != reduce_util.ReduceOp.SUM: 1344 raise NotImplementedError( 1345 f"`reduce_op`={reduce_op} is not supported. Currently we only " 1346 "support ReduceOp.SUM and ReduceOp.MEAN in TPUStrategy.") 1347 return tpu_ops.cross_replica_sum(value) 1348 1349 if not isinstance(value, values.DistributedValues): 1350 # This function handles reducing values that are not PerReplica or 1351 # Mirrored values. For example, the same value could be present on all 1352 # replicas in which case `value` would be a single value or value could 1353 # be 0. 1354 return cross_device_ops_lib.reduce_non_distributed_value( 1355 reduce_op, value, destinations, self._num_replicas_in_sync) 1356 1357 value_list = value.values 1358 # pylint: disable=protected-access 1359 if isinstance( 1360 value, 1361 values.DistributedVariable) and value._packed_variable is not None: 1362 value_list = tuple( 1363 value._packed_variable.on_device(d) 1364 for d in value._packed_variable.devices) 1365 # pylint: enable=protected-access 1366 1367 # Currently XLA op by op mode has a limit for the number of inputs for a 1368 # single op, thus we break one `add_n` op into a group of `add_n` ops to 1369 # work around the constraint. 1370 # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`. 1371 if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT: 1372 output = math_ops.add_n(value_list) 1373 else: 1374 output = array_ops.zeros_like(value_list[0], dtype=value_list[0].dtype) 1375 for i in range(0, len(value_list), _XLA_OP_BY_OP_INPUTS_LIMIT): 1376 output += math_ops.add_n(value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT]) 1377 1378 if reduce_op == reduce_util.ReduceOp.MEAN: 1379 output *= (1. / len(value_list)) 1380 1381 output = self._broadcast_output(destinations, output) 1382 return output 1383 1384 def _update(self, var, fn, args, kwargs, group): 1385 assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance( 1386 var, resource_variable_ops.BaseResourceVariable) 1387 if tpu_util.enclosing_tpu_context() is not None: 1388 if group: 1389 return fn(var, *args, **kwargs) 1390 else: 1391 return (fn(var, *args, **kwargs),) 1392 1393 # Inside `tf.function`, we don't expand PackedVariable in python as it will 1394 # be expanded later during function instantiation in the runtime. 1395 packed_var = var._packed_variable # pylint: disable=protected-access 1396 if packed_var is not None and not context.executing_eagerly(): 1397 if group: 1398 return fn(packed_var, *args, **kwargs) 1399 else: 1400 return (fn(packed_var, *args, **kwargs),) 1401 1402 # Otherwise, we revert to MirroredStrategy behavior and update the variable 1403 # on each replica directly. 1404 updates = [] 1405 values_and_devices = [] 1406 if packed_var is not None: 1407 for device in packed_var.devices: 1408 values_and_devices.append((packed_var, device)) 1409 else: 1410 for value in var.values: 1411 values_and_devices.append((value, value.device)) 1412 1413 if (var.synchronization != variables_lib.VariableSynchronization.ON_READ and 1414 var.aggregation != variables_lib.VariableAggregation.NONE): 1415 distribute_utils.assert_mirrored(args) 1416 distribute_utils.assert_mirrored(kwargs) 1417 for i, value_and_device in enumerate(values_and_devices): 1418 value = value_and_device[0] 1419 device = value_and_device[1] 1420 name = "update_%d" % i 1421 with ops.device(device), \ 1422 distribute_lib.UpdateContext(i), \ 1423 ops.name_scope(name): 1424 # If args and kwargs are not mirrored, the value is returned as is. 1425 updates.append( 1426 fn(value, *distribute_utils.select_replica(i, args), 1427 **distribute_utils.select_replica(i, kwargs))) 1428 return distribute_utils.update_regroup(self, updates, group) 1429 1430 def read_var(self, var): 1431 assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance( 1432 var, resource_variable_ops.BaseResourceVariable) 1433 return var.read_value() 1434 1435 def value_container(self, value): 1436 return value 1437 1438 def _broadcast_to(self, tensor, destinations): 1439 del destinations 1440 # This is both a fast path for Python constants, and a way to delay 1441 # converting Python values to a tensor until we know what type it 1442 # should be converted to. Otherwise we have trouble with: 1443 # global_step.assign_add(1) 1444 # since the `1` gets broadcast as an int32 but global_step is int64. 1445 if isinstance(tensor, (float, int)): 1446 return tensor 1447 if tpu_util.enclosing_tpu_context() is not None: 1448 broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)] 1449 result = tpu_ops.all_to_all( 1450 broadcast_tensor, 1451 concat_dimension=0, 1452 split_dimension=0, 1453 split_count=self._num_replicas_in_sync) 1454 1455 # This uses the broadcasted value from the first replica because the only 1456 # caller of this is for ONLY_FIRST_REPLICA variables aggregation. 1457 return result[0] 1458 return tensor 1459 1460 @property 1461 def num_hosts(self): 1462 if self._device_assignment is None: 1463 return self._tpu_metadata.num_hosts 1464 1465 return len(set([self._device_assignment.host_device(r) 1466 for r in range(self._device_assignment.num_replicas)])) 1467 1468 @property 1469 def num_replicas_per_host(self): 1470 if self._device_assignment is None: 1471 return self._tpu_metadata.num_of_cores_per_host 1472 1473 # TODO(sourabhbajaj): Remove this method we use inputs and remove infeed 1474 # as the computation of num_replicas_per_host is not a constant 1475 # when using device_assignment. This is a temporary workaround to support 1476 # StatefulRNN as everything is 1 in that case. 1477 # This method needs to take host_id as input for correct computation. 1478 max_models_per_host = (self._tpu_metadata.num_of_cores_per_host // 1479 self._device_assignment.num_cores_per_replica) 1480 return min(self._device_assignment.num_replicas, max_models_per_host) 1481 1482 @property 1483 def _num_replicas_in_sync(self): 1484 if self._device_assignment is None: 1485 return self._tpu_metadata.num_cores 1486 return self._device_assignment.num_replicas 1487 1488 @property 1489 def experimental_between_graph(self): 1490 return False 1491 1492 @property 1493 def experimental_should_init(self): 1494 return True 1495 1496 @property 1497 def should_checkpoint(self): 1498 return True 1499 1500 @property 1501 def should_save_summary(self): 1502 return True 1503 1504 @property 1505 def worker_devices(self): 1506 return tuple(self._tpu_devices[:, self._logical_device_stack[-1]]) 1507 1508 @property 1509 def parameter_devices(self): 1510 return self.worker_devices 1511 1512 @property 1513 def tpu_hardware_feature(self): 1514 """Return the `tf.tpu.experimental.HardwareFeature` class.""" 1515 return tpu_hardware_feature.HardwareFeature( 1516 self._tpu_cluster_resolver.tpu_hardware_feature) 1517 1518 def non_slot_devices(self, var_list): 1519 return self._host_device 1520 1521 def _update_non_slot(self, colocate_with, fn, args, kwargs, group): 1522 del colocate_with 1523 with ops.device(self._host_device), distribute_lib.UpdateContext(None): 1524 result = fn(*args, **kwargs) 1525 if group: 1526 return result 1527 else: 1528 return nest.map_structure(self._local_results, result) 1529 1530 def _configure(self, 1531 session_config=None, 1532 cluster_spec=None, 1533 task_type=None, 1534 task_id=None): 1535 del cluster_spec, task_type, task_id 1536 if session_config: 1537 session_config.CopyFrom(self._update_config_proto(session_config)) 1538 1539 def _update_config_proto(self, config_proto): 1540 updated_config = copy.deepcopy(config_proto) 1541 updated_config.isolate_session_state = True 1542 cluster_spec = self._tpu_cluster_resolver.cluster_spec() 1543 if cluster_spec: 1544 updated_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) 1545 return updated_config 1546 1547 # TODO(priyag): Delete this once all strategies use global batch size. 1548 @property 1549 def _global_batch_size(self): 1550 """`make_dataset_iterator` and `make_numpy_iterator` use global batch size. 1551 1552 `make_input_fn_iterator` assumes per-replica batching. 1553 1554 Returns: 1555 Boolean. 1556 """ 1557 return True 1558 1559 def tpu_run(self, fn, args, kwargs, options=None): 1560 func = self._tpu_function_creator(fn, options) 1561 return func(args, kwargs) 1562 1563 def _tpu_function_creator(self, fn, options): 1564 if context.executing_eagerly() and fn in self._tpu_function_cache: 1565 return self._tpu_function_cache[fn] 1566 1567 strategy = self._container_strategy() 1568 1569 def tpu_function(args, kwargs): 1570 """TF Function used to replicate the user computation.""" 1571 logging.vlog(1, 1572 "`TPUStrategy.run` is called with [args: %s] [kwargs: %s]", 1573 args, kwargs) 1574 1575 if kwargs is None: 1576 kwargs = {} 1577 1578 # Used to re-structure flattened output tensors from `tpu.replicate()` 1579 # into a structured format. 1580 result = [[]] 1581 1582 def replicated_fn(replica_id, replica_args, replica_kwargs): 1583 """Wraps user function to provide replica ID and `Tensor` inputs.""" 1584 with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id): 1585 result[0] = fn(*replica_args, **replica_kwargs) 1586 return result[0] 1587 1588 replicate_inputs = [] # By replica. 1589 for i in range(strategy.num_replicas_in_sync): 1590 replicate_inputs.append( 1591 [constant_op.constant(i, dtype=dtypes.int32), 1592 distribute_utils.select_replica(i, args), 1593 distribute_utils.select_replica(i, kwargs)]) 1594 1595 # Construct and pass `maximum_shapes` so that we could support dynamic 1596 # shapes using dynamic padder. 1597 if options.experimental_enable_dynamic_batch_size and replicate_inputs: 1598 maximum_shapes = [] 1599 flattened_list = nest.flatten(replicate_inputs[0]) 1600 for input_tensor in flattened_list: 1601 if tensor_util.is_tf_type(input_tensor): 1602 rank = input_tensor.shape.rank 1603 else: 1604 rank = np.ndim(input_tensor) 1605 if rank is None: 1606 raise ValueError( 1607 "input tensor {} to TPUStrategy.run() has unknown rank, " 1608 "which is not allowed".format(input_tensor)) 1609 maximum_shape = tensor_shape.TensorShape([None] * rank) 1610 maximum_shapes.append(maximum_shape) 1611 maximum_shapes = nest.pack_sequence_as(replicate_inputs[0], 1612 maximum_shapes) 1613 else: 1614 maximum_shapes = None 1615 1616 if options.experimental_bucketizing_dynamic_shape: 1617 padding_spec = tpu.PaddingSpec.POWER_OF_TWO 1618 else: 1619 padding_spec = None 1620 1621 with strategy.scope(): 1622 xla_options = options.experimental_xla_options or tpu.XLAOptions( 1623 use_spmd_for_xla_partitioning=self._use_spmd_for_xla_partitioning) 1624 replicate_outputs = tpu.replicate( 1625 replicated_fn, 1626 replicate_inputs, 1627 device_assignment=self._device_assignment, 1628 maximum_shapes=maximum_shapes, 1629 padding_spec=padding_spec, 1630 xla_options=xla_options) 1631 1632 # Remove all no ops that may have been added during 'tpu.replicate()' 1633 filter_ops = lambda x: [o for o in x if not isinstance(o, ops.Operation)] 1634 if isinstance(result[0], list): 1635 result[0] = filter_ops(result[0]) 1636 1637 # Workaround for `tpu.replicate` behaviour when single `Tensor` returned. 1638 if result[0] is None or isinstance(result[0], ops.Operation): 1639 replicate_outputs = [None] * len(replicate_outputs) 1640 else: 1641 replicate_outputs = [ 1642 nest.pack_sequence_as(result[0], filter_ops(nest.flatten(output))) 1643 for output in replicate_outputs 1644 ] 1645 return distribute_utils.regroup(replicate_outputs) 1646 1647 if context.executing_eagerly(): 1648 tpu_function = def_function.function(tpu_function) 1649 self._tpu_function_cache[fn] = tpu_function 1650 return tpu_function 1651 1652 def _in_multi_worker_mode(self): 1653 """Whether this strategy indicates working in multi-worker settings.""" 1654 # TPUStrategy has different distributed training structure that the whole 1655 # cluster should be treated as single worker from higher-level (e.g. Keras) 1656 # library's point of view. 1657 # TODO(rchao): Revisit this as we design a fault-tolerance solution for 1658 # TPUStrategy. 1659 return False 1660 1661 def _get_local_replica_id(self, replica_id_in_sync_group): 1662 return replica_id_in_sync_group 1663 1664 1665def _make_axis_nonnegative(axis, rank): 1666 # Convert a potentially negative `axis` to a non-negative one. 1667 if isinstance(axis, int): 1668 if axis >= 0: 1669 return axis 1670 else: 1671 return axis + rank 1672 else: 1673 return array_ops.where_v2( 1674 math_ops.greater_equal(axis, 0), 1675 axis, 1676 axis + rank) 1677 1678 1679# List of Tensor dtypes supported by cross_replica_sum(). 1680_DTYPES_SUPPORTED_BY_CROSS_REPLICA_SUM = ( 1681 dtypes.bfloat16, 1682 dtypes.float16, 1683 dtypes.float32, 1684 dtypes.float64, 1685 dtypes.int32, 1686 dtypes.uint32, 1687) 1688 1689 1690class _TPUReplicaContext(distribute_lib.ReplicaContext): 1691 """Replication Context class for TPU Strategy.""" 1692 1693 # TODO(sourabhbajaj): Call for each replica should be updating this. 1694 # TODO(b/118385803): Always properly initialize replica_id. 1695 def __init__(self, strategy, replica_id_in_sync_group=0): 1696 distribute_lib.ReplicaContext.__init__( 1697 self, strategy, replica_id_in_sync_group=replica_id_in_sync_group) 1698 1699 @property 1700 def devices(self): 1701 distribute_lib.require_replica_context(self) 1702 ds = self._strategy 1703 replica_id = tensor_util.constant_value(self.replica_id_in_sync_group) 1704 1705 if replica_id is None: # Non-constant `Tensor` inside `tpu.replicate`. 1706 # TODO(cjfj): Return other devices when model parallelism is supported. 1707 return (tpu.core(0),) 1708 else: 1709 return (ds.extended.worker_devices[replica_id],) 1710 1711 def experimental_logical_device(self, logical_device_id): 1712 """Places variables and ops on the specified logical device.""" 1713 return self.strategy.extended.experimental_logical_device(logical_device_id) 1714 1715 def _compute_all_gather_output_shape(self, value_shape, value_rank, axis): 1716 if isinstance(value_rank, int): 1717 output_shape = list(value_shape) 1718 output_shape[axis] *= self.num_replicas_in_sync 1719 else: 1720 output_shape = array_ops.where_v2( 1721 math_ops.equal(math_ops.range(value_rank), axis), 1722 value_shape * context.num_replicas_in_sync, 1723 value_shape) 1724 return output_shape 1725 1726 def all_gather(self, value, axis, experimental_hints=None): 1727 del experimental_hints 1728 for v in nest.flatten(value): 1729 if isinstance(v, indexed_slices.IndexedSlices): 1730 raise NotImplementedError("all_gather does not support IndexedSlices") 1731 1732 def _all_gather_tensor(value, axis): 1733 value = ops.convert_to_tensor(value) 1734 1735 # Compute the shape and rank and rank of the input tensor. Use static 1736 # shapes when possible to help with shape inference in graph mode, but 1737 # fall back on dynamic shapes when necessary. 1738 if value.shape.rank is None: 1739 value_rank = array_ops.rank(value) 1740 value_shape = array_ops.shape(value) 1741 else: 1742 value_rank = value.shape.rank 1743 value_shape = value.shape.as_list() 1744 value_shape_tensor = array_ops.shape(value) 1745 for i in range(len(value_shape)): 1746 if value_shape[i] is None: 1747 value_shape[i] = value_shape_tensor[i] 1748 1749 # In the code below, we will insert a new "replica" dimension immediately 1750 # *before* `axis`. To ensure that it's inserted before and not after, we 1751 # must make `axis` non-negative. 1752 axis = _make_axis_nonnegative(axis, value_rank) 1753 1754 # Create a list or 1D int Tensor such as 1755 # [1, 1, ..., 1, num_replicas_in_sync, 1, ..., 1], 1756 # which is equal to `num_replicas_in_sync` at index `axis` 1757 # and is equal to 1 everywhere else. 1758 if isinstance(value_rank, int): 1759 replica_broadcast_shape = [1] * (value_rank + 1) 1760 replica_broadcast_shape[axis] = self.num_replicas_in_sync 1761 else: 1762 replica_broadcast_shape = array_ops.where_v2( 1763 math_ops.equal(math_ops.range(value_rank+1), axis), 1764 self.num_replicas_in_sync, 1765 1) 1766 1767 output_shape = self._compute_all_gather_output_shape( 1768 value_shape, value_rank, axis) 1769 1770 if value.dtype in _DTYPES_SUPPORTED_BY_CROSS_REPLICA_SUM: 1771 # optimized all_gather implementation based on cross_replica_sum(). 1772 replica_id_mask = array_ops.one_hot( 1773 self.replica_id_in_sync_group, self.num_replicas_in_sync) 1774 replica_id_mask = array_ops.reshape( 1775 replica_id_mask, replica_broadcast_shape) 1776 replica_id_mask = math_ops.cast(replica_id_mask, value.dtype) 1777 1778 gathered_value = array_ops.expand_dims(value, axis) * replica_id_mask 1779 gathered_value = self.all_reduce( 1780 reduce_util.ReduceOp.SUM, gathered_value) 1781 return array_ops.reshape(gathered_value, output_shape) 1782 else: 1783 # value.dtype isn't supported by cross_replica_sum(), so we fall back 1784 # on a less efficient implementation based on all_to_all(). 1785 1786 # The underlying AllToAllOp first do a split of the input value and then 1787 # cross-replica communication and concatenation of the result. So we 1788 # concatenate the local tensor here first. 1789 inputs = array_ops.expand_dims(value, axis=axis) 1790 inputs = array_ops.tile(inputs, replica_broadcast_shape) 1791 unordered_output = tpu_ops.all_to_all( 1792 inputs, 1793 concat_dimension=axis, 1794 split_dimension=axis, 1795 split_count=self.num_replicas_in_sync) 1796 1797 # Re-order since xla.replica_id and ReplicaContext.replica_id mismatch. 1798 # Start by computing a permutation -- a 1D Tensor which maps 1799 # tensor[xla.replica_id] = ReplicaContext.replica_id 1800 concat_replica_id = array_ops.reshape( 1801 self.replica_id_in_sync_group, [1]) 1802 concat_replica_id = array_ops.tile( 1803 concat_replica_id, [self.num_replicas_in_sync]) 1804 xla_to_replica_context_id = tpu_ops.all_to_all( 1805 concat_replica_id, 1806 concat_dimension=0, 1807 split_dimension=0, 1808 split_count=self.num_replicas_in_sync) 1809 1810 # Now invert the mapping to get 1811 # tensor[ReplicaContext.replica_id] = xla.replica_id 1812 replica_context_to_xla_id = math_ops.argmax( 1813 array_ops.one_hot(xla_to_replica_context_id, 1814 self.num_replicas_in_sync), 1815 axis=0) 1816 1817 # Reorder the output elements so that they're sorted based on 1818 # ReplicaContext.replica_id instead of xla.replica_id. 1819 sorted_with_extra_dim = array_ops.gather( 1820 unordered_output, replica_context_to_xla_id, axis=axis) 1821 return array_ops.reshape(sorted_with_extra_dim, output_shape) 1822 1823 ys = [_all_gather_tensor(t, axis=axis) for t in nest.flatten(value)] 1824 return nest.pack_sequence_as(value, ys) 1825 1826 1827def _set_last_step_outputs(ctx, last_step_tensor_outputs): 1828 """Sets the last step outputs on the given context.""" 1829 # Convert replicate_outputs to the original dict structure of 1830 # last_step_outputs. 1831 last_step_tensor_outputs_dict = nest.pack_sequence_as( 1832 ctx.last_step_outputs, last_step_tensor_outputs) 1833 1834 for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access 1835 output = last_step_tensor_outputs_dict[name] 1836 # For outputs that aren't reduced, return a PerReplica of all values. Else 1837 # take the first value from the list as each value should be the same. 1838 if reduce_op is None: 1839 last_step_tensor_outputs_dict[name] = values.PerReplica(output) 1840 else: 1841 # TODO(priyag): Should this return the element or a list with 1 element 1842 last_step_tensor_outputs_dict[name] = output[0] 1843 ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access 1844