1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Parameter server strategy V2 class. 16 17This is currently under development and the API is subject to change. 18""" 19 20import functools 21import os 22import threading 23 24from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib 25from tensorflow.python.distribute import device_util 26from tensorflow.python.distribute import distribute_lib 27from tensorflow.python.distribute import input_lib 28from tensorflow.python.distribute import input_util 29from tensorflow.python.distribute import mirrored_run 30from tensorflow.python.distribute import multi_worker_util 31from tensorflow.python.distribute import parameter_server_strategy 32from tensorflow.python.distribute import ps_values 33from tensorflow.python.distribute import sharded_variable 34from tensorflow.python.distribute import values 35from tensorflow.python.eager import remote 36from tensorflow.python.framework import config 37from tensorflow.python.framework import device as tf_device 38from tensorflow.python.framework import ops 39from tensorflow.python.framework import tensor_shape 40from tensorflow.python.ops import array_ops 41from tensorflow.python.ops import resource_variable_ops 42from tensorflow.python.ops import variable_scope as vs 43from tensorflow.python.platform import tf_logging as logging 44from tensorflow.python.trackable import base as trackable 45from tensorflow.python.training import server_lib 46from tensorflow.python.util import nest 47from tensorflow.python.util import tf_inspect 48from tensorflow.python.util.lazy_loader import LazyLoader 49from tensorflow.python.util.tf_export import tf_export 50 51ALLOWED_TASK_TYPES = ("chief", "worker", "ps") 52 53cluster_coordinator = LazyLoader( 54 "cluster_coordinator", globals(), 55 "tensorflow.python.distribute.coordinator.cluster_coordinator" 56) 57 58load_context = LazyLoader( 59 "load_context", globals(), 60 "tensorflow.python.keras.saving.saved_model.load_context" 61) 62 63 64@tf_export( 65 "distribute.experimental.ParameterServerStrategy", 66 "distribute.ParameterServerStrategy", 67 v1=[]) 68class ParameterServerStrategyV2(distribute_lib.Strategy): 69 """An multi-worker tf.distribute strategy with parameter servers. 70 71 Parameter server training is a common data-parallel method to scale up a 72 machine learning model on multiple machines. A parameter server training 73 cluster consists of workers and parameter servers. Variables are created on 74 parameter servers and they are read and updated by workers in each step. 75 By default, workers read and update these variables independently without 76 synchronizing with each other. Under this configuration, it is known as 77 asynchronous training. 78 79 In TensorFlow 2, we recommend an architecture based on central coordination 80 for parameter server training. Each worker and parameter server runs a 81 `tf.distribute.Server`, and on top of that, a coordinator task is responsible 82 for creating resources on workers and parameter servers, dispatching 83 functions, and coordinating the training. The coordinator uses a 84 `tf.distribute.experimental.coordinator.ClusterCoordinator` to coordinate the 85 cluster, and a `tf.distribute.experimental.ParameterServerStrategy` to define 86 variables on parameter servers and computation on workers. 87 88 For the training to work, the coordinator dispatches `tf.function`s to be 89 executed on remote workers. Upon receiving requests from the coordinator, a 90 worker executes the `tf.function` by reading the variables from parameter 91 servers, executing the ops, and updating the variables on the parameter 92 servers. Each of the worker only processes the requests from the coordinator, 93 and communicates with parameter servers, without direct interactions with 94 other workers in the cluster. 95 96 As a result, failures of some workers do not prevent the cluster from 97 continuing the work, and this allows the cluster to train with instances that 98 can be occasionally unavailable (e.g. preemptible or spot instances). The 99 coordinator and parameter servers though, must be available at all times for 100 the cluster to make progress. 101 102 Note that the coordinator is not one of the training workers. Instead, it 103 creates resources such as variables and datasets, dispatches `tf.function`s, 104 saves checkpoints and so on. In addition to workers, parameter servers and 105 the coordinator, an optional evaluator can be run on the side that 106 periodically reads the checkpoints saved by the coordinator and runs 107 evaluations against each checkpoint. 108 109 `ParameterServerStrategy` is supported with two training APIs: [Custom 110 Training Loop (CTL)] 111 (https://www.tensorflow.org/tutorials/distribute/custom_training) 112 and [Keras Training API, also known as `Model.fit`] 113 (https://www.tensorflow.org/tutorials/distribute/keras). CTL is recommended 114 when users prefer to define the details of their training loop, and 115 `Model.fit` is recommended when users prefer a high-level abstraction and 116 handling of training. 117 118 When using a CTL, `ParameterServerStrategy` has to work in conjunction with a 119 `tf.distribute.experimental.coordinator.ClusterCoordinator` object. 120 121 When using `Model.fit`, currently only the 122 `tf.keras.utils.experimental.DatasetCreator` input type is supported. 123 124 __Example code for coordinator__ 125 126 This section provides code snippets that are intended to be run on (the only) 127 one task that is designated as the coordinator. Note that `cluster_resolver`, 128 `variable_partitioner`, and `dataset_fn` arguments are explained in the 129 following "Cluster setup", "Variable partitioning", and "Dataset preparation" 130 sections. 131 132 With a CTL, 133 134 ```python 135 # Prepare a strategy to use with the cluster and variable partitioning info. 136 strategy = tf.distribute.experimental.ParameterServerStrategy( 137 cluster_resolver=..., 138 variable_partitioner=...) 139 coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator( 140 strategy=strategy) 141 142 # Prepare a distribute dataset that will place datasets on the workers. 143 distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn=...) 144 145 with strategy.scope(): 146 model = ... 147 optimizer, metrics = ... # Keras optimizer/metrics are great choices 148 checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer) 149 checkpoint_manager = tf.train.CheckpointManager( 150 checkpoint, checkpoint_dir, max_to_keep=2) 151 # `load_checkpoint` infers initial epoch from `optimizer.iterations`. 152 initial_epoch = load_checkpoint(checkpoint_manager) or 0 153 154 @tf.function 155 def worker_fn(iterator): 156 157 def replica_fn(inputs): 158 batch_data, labels = inputs 159 # calculate gradient, applying gradient, metrics update etc. 160 161 strategy.run(replica_fn, args=(next(iterator),)) 162 163 for epoch in range(initial_epoch, num_epoch): 164 distributed_iterator = iter(distributed_dataset) # Reset iterator state. 165 for step in range(steps_per_epoch): 166 167 # Asynchronously schedule the `worker_fn` to be executed on an arbitrary 168 # worker. This call returns immediately. 169 coordinator.schedule(worker_fn, args=(distributed_iterator,)) 170 171 # `join` blocks until all scheduled `worker_fn`s finish execution. Once it 172 # returns, we can read the metrics and save checkpoints as needed. 173 coordinator.join() 174 logging.info('Metric result: %r', metrics.result()) 175 train_accuracy.reset_states() 176 checkpoint_manager.save() 177 ``` 178 179 With `Model.fit`, 180 181 ```python 182 # Prepare a strategy to use with the cluster and variable partitioning info. 183 strategy = tf.distribute.experimental.ParameterServerStrategy( 184 cluster_resolver=..., 185 variable_partitioner=...) 186 187 # A dataset function takes a `input_context` and returns a `Dataset` 188 def dataset_fn(input_context): 189 dataset = tf.data.Dataset.from_tensors(...) 190 return dataset.repeat().shard(...).batch(...).prefetch(...) 191 192 # With `Model.fit`, a `DatasetCreator` needs to be used. 193 input = tf.keras.utils.experimental.DatasetCreator(dataset_fn=...) 194 195 with strategy.scope(): 196 model = ... # Make sure the `Model` is created within scope. 197 model.compile(optimizer="rmsprop", loss="mse", steps_per_execution=..., ...) 198 199 # Optional callbacks to checkpoint the model, back up the progress, etc. 200 callbacks = [tf.keras.callbacks.ModelCheckpoint(...), ...] 201 202 # `steps_per_epoch` is required with `ParameterServerStrategy`. 203 model.fit(input, epochs=..., steps_per_epoch=..., callbacks=callbacks) 204 ``` 205 206 __Example code for worker and parameter servers__ 207 208 In addition to the coordinator, there should be tasks designated as 209 "worker" or "ps". They should run the following code to start a TensorFlow 210 server, waiting for coordinator's requests: 211 212 ```python 213 # Provide a `tf.distribute.cluster_resolver.ClusterResolver` that serves 214 # the cluster information. See below "Cluster setup" section. 215 cluster_resolver = ... 216 217 server = tf.distribute.Server( 218 cluster_resolver.cluster_spec(), 219 job_name=cluster_resolver.task_type, 220 task_index=cluster_resolver.task_id, 221 protocol="grpc") 222 223 # Blocking the process that starts a server from exiting. 224 server.join() 225 ``` 226 227 __Cluster setup__ 228 229 In order for the tasks in the cluster to know other tasks' addresses, 230 a `tf.distribute.cluster_resolver.ClusterResolver` is required to be used 231 in coordinator, worker, and ps. The 232 `tf.distribute.cluster_resolver.ClusterResolver` is responsible for providing 233 the cluster information, as well as the task type and id of the current task. 234 See `tf.distribute.cluster_resolver.ClusterResolver` for more information. 235 236 If `TF_CONFIG` environment variable is set, a 237 `tf.distribute.cluster_resolver.TFConfigClusterResolver` should be used as 238 well. 239 240 Since there are assumptions in 241 `tf.distribute.experimental.ParameterServerStrategy` around the naming of the 242 task types, "chief", "ps", and "worker" should be used in the 243 `tf.distribute.cluster_resolver.ClusterResolver` to refer to the coordinator, 244 parameter servers, and workers, respectively. 245 246 The following example demonstrates setting `TF_CONFIG` for the task designated 247 as a parameter server (task type "ps") and index 1 (the second task), in a 248 cluster with 1 chief, 2 parameter servers, and 3 workers. Note that it needs 249 to be set before the use of 250 `tf.distribute.cluster_resolver.TFConfigClusterResolver`. 251 252 Example code for cluster setup: 253 ```python 254 os.environ['TF_CONFIG'] = ''' 255 { 256 "cluster": { 257 "chief": ["chief.example.com:2222"], 258 "ps": ["ps0.example.com:2222", "ps1.example.com:2222"], 259 "worker": ["worker0.example.com:2222", "worker1.example.com:2222", 260 "worker2.example.com:2222"] 261 }, 262 "task": { 263 "type": "ps", 264 "index": 1 265 } 266 } 267 ''' 268 ``` 269 270 If you prefer to run the same binary for all tasks, you will need to let the 271 binary branch into different roles at the beginning of the program: 272 ```python 273 # If coordinator, create a strategy and start the training program. 274 if cluster_resolver.task_type == 'chief': 275 strategy = tf.distribute.experimental.ParameterServerStrategy( 276 cluster_resolver) 277 ... 278 279 # If worker/ps, create a server 280 elif cluster_resolver.task_type in ("worker", "ps"): 281 server = tf.distribute.Server(...) 282 ... 283 ``` 284 Alternatively, you can also start a bunch of TensorFlow servers in advance and 285 connect to them later. The coordinator can be in the same cluster or on any 286 machine that has connectivity to workers and parameter servers. This is 287 covered in our guide and tutorial. 288 289 __Variable creation with `strategy.scope()`__ 290 291 `tf.distribute.experimental.ParameterServerStrategy` follows the 292 `tf.distribute` API contract where variable creation is expected to be inside 293 the context manager returned by `strategy.scope()`, in order to be correctly 294 placed on parameter servers in a round-robin manner: 295 296 ```python 297 # In this example, we're assuming having 3 ps. 298 strategy = tf.distribute.experimental.ParameterServerStrategy( 299 cluster_resolver=...) 300 coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator( 301 strategy=strategy) 302 303 # Variables should be created inside scope to be placed on parameter servers. 304 # If created outside scope such as `v1` here, it would be placed on the 305 # coordinator. 306 v1 = tf.Variable(initial_value=0.0) 307 308 with strategy.scope(): 309 v2 = tf.Variable(initial_value=1.0) 310 v3 = tf.Variable(initial_value=2.0) 311 v4 = tf.Variable(initial_value=3.0) 312 v5 = tf.Variable(initial_value=4.0) 313 314 # v2 through v5 are created in scope and are distributed on parameter servers. 315 # Default placement is round-robin but the order should not be relied on. 316 assert v2.device == "/job:ps/replica:0/task:0/device:CPU:0" 317 assert v3.device == "/job:ps/replica:0/task:1/device:CPU:0" 318 assert v4.device == "/job:ps/replica:0/task:2/device:CPU:0" 319 assert v5.device == "/job:ps/replica:0/task:0/device:CPU:0" 320 ``` 321 322 See `distribute.Strategy.scope` for more information. 323 324 __Variable partitioning__ 325 326 Having dedicated servers to store variables means being able to divide up, or 327 "shard" the variables across the ps. Partitioning large variable among ps is a 328 commonly used technique to boost training throughput and mitigate memory 329 constraints. It enables parallel computations and updates on different shards 330 of a variable, and often yields better load balancing across parameter 331 servers. Without sharding, models with large variables (e.g, embeddings) that 332 can't fit into one machine's memory would otherwise be unable to train. 333 334 With `tf.distribute.experimental.ParameterServerStrategy`, if a 335 `variable_partitioner` is provided to `__init__` and certain conditions are 336 satisfied, the resulting variables created in scope are sharded across the 337 parameter servers, in a round-robin fashion. The variable reference returned 338 from `tf.Variable` becomes a type that serves as the container of the sharded 339 variables. One can access `variables` attribute of this container for the 340 actual variable components. If building model with `tf.Module` or Keras, 341 the variable components are collected in the `variables` alike attributes. 342 343 It is recommended to use size-based partitioners like 344 `tf.distribute.experimental.partitioners.MinSizePartitioner` to avoid 345 partitioning small variables, which could have negative impact on model 346 training speed. 347 348 ```python 349 # Partition the embedding layer into 2 shards. 350 variable_partitioner = ( 351 tf.distribute.experimental.partitioners.MinSizePartitioner( 352 min_shard_bytes=(256 << 10), 353 max_shards = 2)) 354 strategy = tf.distribute.experimental.ParameterServerStrategy( 355 cluster_resolver=..., 356 variable_partitioner = variable_partitioner) 357 with strategy.scope(): 358 embedding = tf.keras.layers.Embedding(input_dim=1024, output_dim=1024) 359 assert len(embedding.variables) == 2 360 assert isinstance(embedding.variables[0], tf.Variable) 361 assert isinstance(embedding.variables[1], tf.Variable) 362 assert embedding.variables[0].shape == (512, 1024) 363 assert embedding.variables[1].shape == (512, 1024) 364 ``` 365 366 The sharded variable container can be converted to a `Tensor` via 367 `tf.convert_to_tensor`. This means the container can be directly used in most 368 Python Ops where such `Tensor` conversion automatically happens. For example, 369 in the above code snippet, `x * self.w` would implicitly apply the said tensor 370 conversion. Note that such conversion can be expensive, as the variable 371 components need to be transferred from multiple parameter servers to where 372 the value is used. 373 374 `tf.nn.embedding_lookup` on the other hand doesn't apply the tensor 375 conversion, and performs parallel lookups on the variable components instead. 376 This is crucial to scale up embedding lookups when the embedding table 377 variable is large. 378 379 When a partitioned variable is saved to a `SavedModel`, it will be saved as if 380 it is one single variable. This improves serving efficiency by eliminating 381 a number of Ops that handle the partiton aspects. 382 383 Known limitations of variable partitioning: 384 385 * Number of partitions must not change across Checkpoint saving/loading. 386 387 * After saving partitioned variables to a SavedModel, the SavedModel can't be 388 loaded via `tf.saved_model.load`. 389 390 * Partition variable doesn't directly work with `tf.GradientTape`, please use 391 the `variables` attributes to get the actual variable components and use 392 them in gradient APIs instead. 393 394 __Dataset preparation__ 395 396 With `tf.distribute.experimental.ParameterServerStrategy`, a dataset is 397 created in each of the workers to be used for training. This is done by 398 creating a `dataset_fn` that takes no argument and returns a 399 `tf.data.Dataset`, and passing the `dataset_fn` into 400 `tf.distribute.experimental.coordinator. 401 ClusterCoordinator.create_per_worker_dataset`. We recommend the dataset to be 402 shuffled and repeated to have the examples run through the training as evenly 403 as possible. 404 405 ```python 406 def dataset_fn(): 407 filenames = ... 408 dataset = tf.data.Dataset.from_tensor_slices(filenames) 409 410 # Dataset is recommended to be shuffled, and repeated. 411 return dataset.shuffle(buffer_size=...).repeat().batch(batch_size=...) 412 413 coordinator = 414 tf.distribute.experimental.coordinator.ClusterCoordinator(strategy=...) 415 distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn) 416 ``` 417 418 __Limitations__ 419 420 * `tf.distribute.experimental.ParameterServerStrategy` in TF2 is experimental, 421 and the API is subject to further changes. 422 423 * When using `Model.fit`, `tf.distribute.experimental.ParameterServerStrategy` 424 must be used with a `tf.keras.utils.experimental.DatasetCreator`, and 425 `steps_per_epoch` must be specified. 426 """ 427 428 # pyformat: disable 429 def __init__(self, cluster_resolver, variable_partitioner=None): 430 """Initializes the TF2 parameter server strategy. 431 432 This initializes the `tf.distribute.experimental.ParameterServerStrategy` 433 object to be ready for use with 434 `tf.distribute.experimental.coordinator.ClusterCoordinator`. 435 436 Args: 437 cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver` 438 object. 439 variable_partitioner: 440 a `distribute.experimental.partitioners.Partitioner` that specifies 441 how to partition variables. If `None`, variables will not be 442 partitioned. 443 444 * Predefined partitioners in `tf.distribute.experimental.partitioners` 445 can be used for this argument. A commonly used partitioner is 446 `MinSizePartitioner(min_shard_bytes = 256 << 10, max_shards = num_ps)`, 447 which allocates at least 256K per shard, and each ps gets at most one 448 shard. 449 450 * `variable_partitioner` will be called for each variable created under 451 strategy `scope` to instruct how the variable should be partitioned. 452 Variables that have only one partition along the partitioning axis 453 (i.e., no need for partition) will be created as a normal `tf.Variable`. 454 455 * Only the first / outermost axis partitioning is supported. 456 457 * Div partition strategy is used to partition variables. Assuming we 458 assign consecutive integer ids along the first axis of a variable, then 459 ids are assigned to shards in a contiguous manner, while attempting to 460 keep each shard size identical. If the ids do not evenly divide the 461 number of shards, each of the first several shards will be assigned one 462 more id. For instance, a variable whose first dimension is 13 has 13 463 ids, and they are split across 5 shards as: 464 `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`. 465 466 * Variables created under `strategy.extended.colocate_vars_with` will 467 not be partitioned. 468 """ 469 # pyformat: enable 470 self._cluster_resolver = cluster_resolver 471 472 self._verify_args_and_config(cluster_resolver) 473 self._cluster_coordinator = None 474 logging.info( 475 "`tf.distribute.experimental.ParameterServerStrategy` is initialized " 476 "with cluster_spec: %s", cluster_resolver.cluster_spec()) 477 478 # TODO(b/167894802): Make coordinator, worker, and ps names customizable. 479 self._connect_to_cluster(coordinator_name="chief") 480 self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver, 481 variable_partitioner) 482 super(ParameterServerStrategyV2, self).__init__(self._extended) 483 distribute_lib.distribution_strategy_gauge.get_cell("V2").set( 484 "ParameterServerStrategy") 485 self._should_use_with_coordinator = True 486 # Used while constructing distributed iterators. 487 self._canonicalize_devices = False 488 489 def _connect_to_cluster(self, coordinator_name): 490 if coordinator_name in ["worker", "ps"]: 491 raise ValueError("coordinator name should not be 'worker' or 'ps'.") 492 cluster_spec = self._cluster_resolver.cluster_spec() 493 self._num_workers = len(cluster_spec.as_dict().get("worker", ())) 494 self._num_ps = len(cluster_spec.as_dict().get("ps", ())) 495 496 device_filters = server_lib.ClusterDeviceFilters() 497 # For any worker, only the devices on ps and coordinator nodes are visible 498 for i in range(self._num_workers): 499 device_filters.set_device_filters( 500 "worker", i, ["/job:ps", "/job:%s" % coordinator_name]) 501 # Similarly for any ps, only the devices on workers and coordinator are 502 # visible 503 for i in range(self._num_ps): 504 device_filters.set_device_filters( 505 "ps", i, ["/job:worker", "/job:%s" % coordinator_name]) 506 507 # Allow at most one outstanding RPC for each worker at a certain time. This 508 # is to simplify worker failure handling in the runtime 509 os.environ["TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"] = "False" 510 511 # Disable async executors to make context.async_wait a no-op. This avoids 512 # sending RPCs to remote workers since the executors used by PSStrategy 513 # are known to be always synchronous. 514 os.environ["TF_PS_DISABLE_ASYNC_EXECUTOR_GLOBALLY"] = "True" 515 516 logging.info("%s is now connecting to cluster with cluster_spec: %r", 517 self.__class__.__name__, cluster_spec) 518 remote.connect_to_cluster( 519 cluster_spec, 520 job_name=coordinator_name, 521 protocol=self._cluster_resolver.rpc_layer, 522 cluster_device_filters=device_filters) 523 524 distribute_lib.distribution_strategy_replica_gauge.get_cell( 525 "ps_strategy_num_workers").set(self._num_workers) 526 distribute_lib.distribution_strategy_replica_gauge.get_cell( 527 "ps_strategy_num_ps").set(self._num_ps) 528 529 def _verify_args_and_config(self, cluster_resolver): 530 if not cluster_resolver.cluster_spec(): 531 raise ValueError("Cluster spec must be non-empty in " 532 "`tf.distribute.cluster_resolver.ClusterResolver`.") 533 cluster_spec = cluster_resolver.cluster_spec() 534 535 # The following checks if the task types are allowed (chief, ps, worker). 536 multi_worker_util._validate_cluster_spec( # pylint: disable=protected-access 537 cluster_spec, cluster_resolver.task_type, cluster_resolver.task_id) 538 539 if multi_worker_util.task_count(cluster_spec, "ps") < 1: 540 raise ValueError("There must be at least one ps.") 541 542 if multi_worker_util.task_count(cluster_spec, "worker") < 1: 543 raise ValueError("There must be at least one worker.") 544 545 546class ParameterServerStrategyV2Extended( 547 parameter_server_strategy.ParameterServerStrategyExtended): 548 """Extended class for ParameterServerStrategyV2. 549 550 Please see `tf.distribute.StrategyExtended` doc for more information. 551 """ 552 553 def __init__(self, container_strategy, cluster_resolver, 554 variable_partitioner): 555 """Initialization of ParameterServerStrategyV2Extended.""" 556 super(ParameterServerStrategyV2Extended, self).__init__(container_strategy) 557 self._num_ps = len(cluster_resolver.cluster_spec().as_dict().get("ps", [])) 558 self._num_workers = len(cluster_resolver.cluster_spec().as_dict().get( 559 "worker", [])) 560 self._variable_count = 0 561 562 self._variable_partitioner = variable_partitioner 563 # The following two attrs are to verify that `ParameterServerStrategy` 564 # methods are properly used with a `ClusterCoordinator`. 565 self._used_with_coordinator = False 566 self._being_scheduled = False 567 self._set_num_gpus() 568 distribute_lib.distribution_strategy_replica_gauge.get_cell( 569 "num_gpus_per_worker").set(self._num_gpus_per_worker) 570 571 # Don't canonicalize the devices here since this code is executed on Chief, 572 # but we want the reduce evaluation to be done on each worker. Placer will 573 # automatically choose the right device based on current context. 574 # TODO(ishark): Use select_cross_device_ops instead. 575 self._cross_device_ops = cross_device_ops_lib.ReductionToOneDevice( 576 reduce_to_device="/device:CPU:0") 577 self._cross_device_ops._canonicalize_devices = False # pylint: disable=protected-access 578 self._allow_run_without_coordinator = False 579 self._coordinator_creation_lock = threading.Lock() 580 581 def _set_num_gpus(self): 582 devices = config.list_logical_devices("GPU") 583 per_worker_gpus = {} 584 for d in devices: 585 d_spec = tf_device.DeviceSpec.from_string(d.name) 586 if d_spec.device_type == "GPU" and d_spec.job == "worker": 587 # TODO(b/167894802): update if worker name is customizable 588 job_spec = d_spec.replace(device_type=None, device_index=None) 589 per_worker_gpus[job_spec] = per_worker_gpus.get(job_spec, 0) + 1 590 591 num_gpus = 0 592 for _, count in per_worker_gpus.items(): 593 if num_gpus > 0 and count != num_gpus: 594 raise ValueError("Mismatched number of GPUs per worker") 595 num_gpus = count 596 597 self._num_gpus_per_worker = num_gpus 598 logging.info(f"Number of GPUs on workers: {self._num_gpus_per_worker}") 599 600 @property 601 def _num_replicas_in_sync(self): 602 return self._num_gpus_per_worker or 1 603 604 def _create_var_creator(self, next_creator, **kwargs): 605 aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE) 606 607 def var_creator(**kwargs): 608 """Create an AggregatingVariable.""" 609 # Create and wrap the variable. 610 v = next_creator(**kwargs) 611 wrapped_v = ps_values.CachingVariable(v) 612 wrapped = ps_values.AggregatingVariable(self._container_strategy(), 613 wrapped_v, aggregation) 614 return wrapped 615 616 if self._num_replicas_in_sync > 1: 617 if aggregation not in (vs.VariableAggregation.NONE, 618 vs.VariableAggregation.SUM, 619 vs.VariableAggregation.MEAN, 620 vs.VariableAggregation.ONLY_FIRST_REPLICA): 621 raise ValueError("Invalid variable aggregation mode: " + aggregation + 622 " for variable: " + kwargs["name"]) 623 return var_creator 624 else: 625 626 def variable_creator_single_replica(**kwargs): 627 v = next_creator(**kwargs) 628 return ps_values.CachingVariable(v) 629 630 return variable_creator_single_replica 631 632 def _create_variable(self, next_creator, **kwargs): 633 """Implements StrategyExtendedV2._create_variable. 634 635 Creates a `Variable` or a `ShardedVariable`. A `ShardedVariable` will be 636 created if satisfying all the following criteria: 637 1. `self._variable_partitioner` results in more than one partition on the 638 first axis. 639 2. variable's rank is greater than 0. 640 3. variable is not colocated with another variable. 641 Otherwise a `Variable` will be created. 642 643 Args: 644 next_creator: See `variable_scope.variable_creator_scope`; the next 645 creator in the chain. 646 **kwargs: Passed through to the next creator. 647 648 Returns: 649 A `Variable` or `ShardedVariable`. 650 """ 651 652 var_creator = self._create_var_creator(next_creator, **kwargs) 653 if "colocate_with" in kwargs: # Never partition colocated_with variables. 654 colocate_with = kwargs["colocate_with"] 655 # Clear the variable scope to avoid possible conflicts between device 656 # scope and colocation scope. 657 with ops.device(None): 658 with ops.colocate_with(colocate_with): 659 var = var_creator(**kwargs) 660 logging.debug( 661 "Creating variable (name:%s, shape:%r) that colocates with %s", 662 var.name, var.shape, kwargs["colocate_with"].name) 663 return var 664 665 if self._variable_partitioner is None: 666 return self._create_variable_round_robin(var_creator, **kwargs) 667 668 name = kwargs.get("name", None) 669 dtype = kwargs.get("dtype", None) 670 shape = kwargs.get("shape", None) 671 initial_value = kwargs.get("initial_value", None) 672 if initial_value is None: 673 # If we are loading, next_creator will return an UninitializedVariable 674 v = next_creator(**kwargs) 675 if not isinstance(v, resource_variable_ops.UninitializedVariable): 676 raise ValueError( 677 "It looks like you are using `ParameterServerStrategy` with a " 678 "`variable_partitioner`, and trying to create a variable without " 679 "specifying `initial_value`. This is not allowed. Please specify the " 680 "`initial_value`.") 681 elif shape is None or dtype is None: 682 raise ValueError( 683 "It looks like you are trying to load a `SavedModel` using " 684 "`tf.saved_model.load` within a `ParameterServerStrategy` scope, " 685 "but the `SavedModel` is missing shape or dtype information.") 686 else: 687 def initializer(shape, dtype, **kwargs): 688 if "partition_shape" in kwargs: 689 shape = kwargs["partition_shape"] 690 return array_ops.zeros(shape, dtype) 691 initial_value = functools.partial(initializer, shape=shape, dtype=dtype) 692 693 # Two cases where initial_value can be a callable: 694 # 1. initial_value is passed as a callable, e.g, an `initializer` class. 695 # 2. restoring from checkpoint, initial_value is a 696 # "CheckpointInitialValueCallable". 697 init_from_fn = callable(initial_value) 698 699 if init_from_fn and (shape is None or dtype is None): 700 init_from_fn = False 701 initial_value = initial_value() 702 if not init_from_fn: 703 # The initial_value is created on coordinator, it will need to be sent to 704 # ps for variable initialization, which can be inefficient and can 705 # potentially hit the 2GB limit on protobuf serialization. 706 initial_value = ops.convert_to_tensor(initial_value, dtype=dtype) 707 dtype = initial_value.dtype 708 shape = initial_value.shape 709 else: 710 shape = tensor_shape.as_shape(shape) 711 712 if shape.rank == 0: # Skip partitioning rank-0 variable. 713 return self._create_variable_round_robin(var_creator, **kwargs) 714 715 num_partitions = self._variable_partitioner(shape=shape, dtype=dtype) 716 if not num_partitions or num_partitions[0] == 0 or any( 717 v != 1 for v in num_partitions[1:]): 718 raise ValueError( 719 "variable_partitioner must return a list/tuple whose elements are 1" 720 " besides the first element (non-zero), got: %r" % num_partitions) 721 722 if num_partitions[0] == 1: # no partition 723 return self._create_variable_round_robin(var_creator, **kwargs) 724 725 # Use "div" partition strategy to partition the variable. 726 num_partitions = min(num_partitions[0], shape[0]) 727 base = shape[0] // num_partitions 728 extra = shape[0] % num_partitions 729 # An example: num_partitions=4, shape[0]=10, partitions: [3, 3, 2, 2] 730 # offsets: [0, 3, 6, 8, 10] 731 offsets = [] 732 for i in range(num_partitions): 733 if i == 0: 734 offsets.append(0) 735 else: 736 prev_shard_size = base + (1 if i - 1 < extra else 0) 737 offsets.append(offsets[i - 1] + prev_shard_size) 738 offsets.append(shape[0]) 739 740 def init_shard_fn(shard_index): 741 if not init_from_fn: 742 logging.log_if( 743 logging.WARN, _INEFFICIENT_INIT_WARNING % name, shard_index == 0 and 744 shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS) 745 return initial_value[offsets[shard_index]:offsets[shard_index + 1]] 746 partition_shape = (offsets[shard_index + 1] - 747 offsets[shard_index],) + shape[1:] 748 partition_offset = (offsets[shard_index],) + (0,) * len(shape[1:]) 749 arg_spec = tf_inspect.getfullargspec(initial_value) 750 if ("shard_info" not in arg_spec.args and 751 "shard_info" not in arg_spec.kwonlyargs): 752 try: 753 value = initial_value( 754 partition_shape=partition_shape, 755 partition_offset=partition_offset) 756 except (TypeError, ValueError): 757 # TypeError: Initializer doesn't accept kwargs 758 # ValueError: Initializer doesn't accept partition kwargs 759 # In both cases we go ahead creating the full value and then slice. 760 value = initial_value() 761 762 if value.shape == partition_shape: 763 # Initializer supports partition: value is the partition value. 764 return value 765 else: 766 # Initializer doesn't support partition: value is the full value 767 # and needs to be sliced to get the partition value. 768 logging.log_if( 769 logging.WARN, _INEFFICIENT_INIT_WARNING % name, 770 shard_index == 0 and 771 shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS) 772 return value[offsets[shard_index]:offsets[shard_index + 1]] 773 else: 774 # For compatibility with `CheckpointInitialValueCallable`. 775 return initial_value( 776 shard_info=trackable.ShardInfo( 777 shape=tensor_shape.as_shape(partition_shape), 778 offset=partition_offset)) 779 780 var_list = [] 781 for i in range(num_partitions): 782 kwargs["shape"] = (offsets[i + 1] - offsets[i],) + shape[1:] 783 kwargs["initial_value"] = lambda: init_shard_fn(i) 784 if name is not None: 785 kwargs["name"] = "{}/part_{}".format(name, i) 786 var_list.append(self._create_variable_round_robin(var_creator, **kwargs)) 787 788 result = sharded_variable.ShardedVariable(var_list) 789 return result 790 791 def _create_variable_round_robin(self, next_creator, **kwargs): 792 # Clear the colocation scope to avoid possible conflicts between device 793 # scope and colocation scope. 794 with ops.colocate_with(None, ignore_existing=True): 795 # Explicitly set CPU:0 device for PS in case create variable is called 796 # inside replica_fn and worker has with GPU:0 scope. 797 with ops.device("/job:ps/task:%d/device:CPU:0" % 798 (self._variable_count % self._num_ps)): 799 var = next_creator(**kwargs) 800 logging.debug( 801 "Creating variable (name:%s, shape:%r) on " 802 "/job:ps/task:%d/device:CPU:0", var.name, var.shape, 803 (self._variable_count % self._num_ps)) 804 self._variable_count += 1 805 return var 806 807 def _resource_creator_scope(self): 808 809 with self._coordinator_creation_lock: 810 if not self._container_strategy()._cluster_coordinator: # pylint: disable=protected-access 811 cluster_coordinator.ClusterCoordinator( 812 strategy=self._container_strategy()) 813 814 # TODO(wxinyi): We should warn the user of the inefficiency of creating 815 # `StaticHashTable` inside a `@tf.function`-wrapped `dataset_fn` to be 816 # distributed with `distribute_datasets_from_function` and 817 # `create_per_worker_dataset`. This is because the `dataset_fn` does not 818 # use the same `default_graph` as `scope` to which the 819 # `resource_creator_stack` belongs. Thus, `StaticHashTable` creation inside 820 # `dataset_fn` is not intercepted. And since its resource creation under a 821 # `tf.function` is lifted out, all workers will share the same resource on 822 # the coordinator which incurs worker-coordinator communication overhead. 823 824 def lookup_creator(next_creator, *args, **kwargs): 825 if load_context.in_load_context(): 826 return (ps_values.RestoredDistributedTable( 827 self._container_strategy(), lambda: next_creator(*args, **kwargs))) # pylint: disable=protected-access 828 else: 829 return ps_values.DistributedTable(self._container_strategy(), 830 lambda: next_creator(*args, **kwargs)) # pylint: disable=protected-access 831 832 def restored_lookup_creator(next_creator, *args, **kwargs): 833 return (ps_values.RestoredDistributedTable( 834 self._container_strategy(), lambda: next_creator(*args, **kwargs))) # pylint: disable=protected-access 835 836 return [ 837 ops.resource_creator_scope("StaticHashTable", lookup_creator), 838 ops.resource_creator_scope("RestoredStaticHashTable", 839 restored_lookup_creator) 840 ] 841 842 def _assert_used_with_cluster_coordinator(self): 843 if (not self._used_with_coordinator and 844 not self._allow_run_without_coordinator): 845 raise NotImplementedError( 846 "`tf.distribute.experimental.ParameterServerStrategy` must be used " 847 "with `tf.distribute.experimental.coordinator.ClusterCoordinator` in " 848 "a custom training loop. If you are using `Model.fit`, please supply " 849 "a dataset function directly to a " 850 "`tf.keras.utils.experimental.DatasetCreator` instead.") 851 852 def _assert_being_scheduled_by_cluster_coordinator(self): 853 if not self._being_scheduled and not self._allow_run_without_coordinator: 854 logging.warning( 855 "A `tf.distribute.experimental.ParameterServerStrategy` method is " 856 "invoked without using `ClusterCoordinator.schedule`. If you are not " 857 "tracing a tf.function, this method is possibly executed on the " 858 "coordinator, which can be slow. To properly dispatch functions to " 859 "run on workers, methods like `run` or `reduce` should be used " 860 "within a function passed to `tf.distribute.experimental.coordinator." 861 "ClusterCoordinator.schedule`.") 862 863 # options is not used right now. But we may want to support options while 864 # creating InputWorkers in future, similar to MirroredStrategy. 865 def _input_workers_with_options(self, options=None): 866 input_workers_devices = (("/device:CPU:0", self.worker_devices),) 867 return input_lib.InputWorkers( 868 input_workers_devices, canonicalize_devices=False) 869 870 def _experimental_distribute_dataset(self, dataset, options): 871 input_workers_devices = self._input_workers_with_options() 872 873 # If this DistributedDataset is created outside ClusterCoordinator, i,e, 874 # outside a tf.function, we don't build its underlying datasets immediately 875 # until it is passed to ClusterCoordinator.create_per_worker_dataset. 876 return input_util.get_distributed_dataset( 877 dataset, 878 input_workers_devices, 879 self._container_strategy(), 880 num_replicas_in_sync=self._num_replicas_in_sync, 881 options=options, 882 build=ops.inside_function()) # will be built by ClusterCoordinator 883 884 def _distribute_datasets_from_function(self, dataset_fn, options): 885 # There is no synchronization beyond a worker and thus, the number of 886 # input pipelines in sync is only 1 per worker. 887 input_pipeline_id_in_sync = 0 888 num_input_pipelines_in_sync = 1 889 890 input_context = distribute_lib.InputContext( 891 num_input_pipelines=num_input_pipelines_in_sync, 892 input_pipeline_id=input_pipeline_id_in_sync, 893 num_replicas_in_sync=self._num_replicas_in_sync) 894 895 # If this DistributedDatasetFromFunction is created outside 896 # ClusterCoordinator, i,e, outside a tf.function, we don't build its 897 # underlying datasets immediately until it is passed to 898 # ClusterCoordinator.create_per_worker_dataset. 899 return input_util.get_distributed_datasets_from_function( 900 dataset_fn, 901 self._input_workers_with_options(options), [input_context], 902 self._container_strategy(), 903 options=options, 904 build=ops.inside_function()) # will be built by ClusterCoordinator 905 906 @property 907 def worker_devices(self): 908 num_gpus = self._num_gpus_per_worker 909 if num_gpus > 0: 910 compute_devices = tuple("/device:GPU:%d" % (i,) for i in range(num_gpus)) 911 else: 912 compute_devices = ("/device:CPU:0",) 913 return compute_devices 914 915 def _call_for_each_replica(self, fn, args, kwargs): 916 self._assert_being_scheduled_by_cluster_coordinator() 917 918 return mirrored_run.call_for_each_replica(self._container_strategy(), fn, 919 args, kwargs) 920 921 def _reduce(self, reduce_op, value): 922 self._assert_being_scheduled_by_cluster_coordinator() 923 dst = device_util.current() or self._default_device or "/device:CPU:0" 924 destinations = device_util.canonicalize_without_job_and_task(dst) 925 result = self._local_results( 926 self.reduce_to(reduce_op, value, destinations))[0] 927 return result 928 929 def _reduce_to(self, reduce_op, value, destinations, options): 930 self._assert_being_scheduled_by_cluster_coordinator() 931 932 def get_values(x): 933 if isinstance(x, values.DistributedValues): 934 return self._cross_device_ops.reduce( 935 reduce_op, x, destinations=destinations) # pylint: disable=protected-access 936 return x 937 938 return nest.map_structure(get_values, value) 939 940 941# The warning that will be logged if the way we initialize sharded variables 942# is memory-inefficient. 943_INEFFICIENT_INIT_WARNING = ( 944 "Large variable %s is partitioned but not initialized in a " 945 "memory-efficient way. On each shard, the full value is first being " 946 "created and then sliced into smaller values. To reduce the memory " 947 "footprint, explicitly specify `dtype` and `shape` when creating " 948 "variables, and use `tf.initializers` to initialize the variable. " 949 "Note that some initializers (e.g., orthogonal) don't support " 950 "memory-efficient initialization and there is not much you can do here.") 951 952_LARGE_VARIABLE_NUM_ELEMENTS = 1e9 953