xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/parameter_server_strategy_v2.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Parameter server strategy V2 class.
16
17This is currently under development and the API is subject to change.
18"""
19
20import functools
21import os
22import threading
23
24from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
25from tensorflow.python.distribute import device_util
26from tensorflow.python.distribute import distribute_lib
27from tensorflow.python.distribute import input_lib
28from tensorflow.python.distribute import input_util
29from tensorflow.python.distribute import mirrored_run
30from tensorflow.python.distribute import multi_worker_util
31from tensorflow.python.distribute import parameter_server_strategy
32from tensorflow.python.distribute import ps_values
33from tensorflow.python.distribute import sharded_variable
34from tensorflow.python.distribute import values
35from tensorflow.python.eager import remote
36from tensorflow.python.framework import config
37from tensorflow.python.framework import device as tf_device
38from tensorflow.python.framework import ops
39from tensorflow.python.framework import tensor_shape
40from tensorflow.python.ops import array_ops
41from tensorflow.python.ops import resource_variable_ops
42from tensorflow.python.ops import variable_scope as vs
43from tensorflow.python.platform import tf_logging as logging
44from tensorflow.python.trackable import base as trackable
45from tensorflow.python.training import server_lib
46from tensorflow.python.util import nest
47from tensorflow.python.util import tf_inspect
48from tensorflow.python.util.lazy_loader import LazyLoader
49from tensorflow.python.util.tf_export import tf_export
50
51ALLOWED_TASK_TYPES = ("chief", "worker", "ps")
52
53cluster_coordinator = LazyLoader(
54    "cluster_coordinator", globals(),
55    "tensorflow.python.distribute.coordinator.cluster_coordinator"
56)
57
58load_context = LazyLoader(
59    "load_context", globals(),
60    "tensorflow.python.keras.saving.saved_model.load_context"
61)
62
63
64@tf_export(
65    "distribute.experimental.ParameterServerStrategy",
66    "distribute.ParameterServerStrategy",
67    v1=[])
68class ParameterServerStrategyV2(distribute_lib.Strategy):
69  """A multi-worker tf.distribute strategy with parameter servers.
70
71  Parameter server training is a common data-parallel method to scale up a
72  machine learning model on multiple machines. A parameter server training
73  cluster consists of workers and parameter servers. Variables are created on
74  parameter servers and they are read and updated by workers in each step.
75  By default, workers read and update these variables independently without
76  synchronizing with each other. This configuration is known as
77  asynchronous training.
78
79  In TensorFlow 2, we recommend an architecture based on central coordination
80  for parameter server training. Each worker and parameter server runs a
81  `tf.distribute.Server`, and on top of that, a coordinator task is responsible
82  for creating resources on workers and parameter servers, dispatching
83  functions, and coordinating the training. The coordinator uses a
84  `tf.distribute.experimental.coordinator.ClusterCoordinator` to coordinate the
85  cluster, and a `tf.distribute.experimental.ParameterServerStrategy` to define
86  variables on parameter servers and computation on workers.
87
88  For the training to work, the coordinator dispatches `tf.function`s to be
89  executed on remote workers. Upon receiving requests from the coordinator, a
90  worker executes the `tf.function` by reading the variables from parameter
91  servers, executing the ops, and updating the variables on the parameter
92  servers. Each worker only processes requests from the coordinator and
93  communicates with parameter servers, without interacting directly with
94  other workers in the cluster.
95
96  As a result, failures of some workers do not prevent the cluster from
97  continuing the work, and this allows the cluster to train with instances that
98  can be occasionally unavailable (e.g. preemptible or spot instances). The
99  coordinator and parameter servers, though, must be available at all times for
100  the cluster to make progress.
101
102  Note that the coordinator is not one of the training workers. Instead, it
103  creates resources such as variables and datasets, dispatches `tf.function`s,
104  saves checkpoints and so on. In addition to workers, parameter servers and
105  the coordinator, an optional evaluator can be run on the side that
106  periodically reads the checkpoints saved by the coordinator and runs
107  evaluations against each checkpoint.
108
109  `ParameterServerStrategy` is supported with two training APIs: [Custom
110  Training Loop (CTL)]
111  (https://www.tensorflow.org/tutorials/distribute/custom_training)
112  and [Keras Training API, also known as `Model.fit`]
113  (https://www.tensorflow.org/tutorials/distribute/keras). CTL is recommended
114  when users prefer to define the details of their training loop, and
115  `Model.fit` is recommended when users prefer a high-level abstraction that
116  handles the training for them.
117
118  When using a CTL, `ParameterServerStrategy` has to work in conjunction with a
119  `tf.distribute.experimental.coordinator.ClusterCoordinator` object.
120
121  When using `Model.fit`, currently only the
122  `tf.keras.utils.experimental.DatasetCreator` input type is supported.
123
124  __Example code for coordinator__
125
126  This section provides code snippets that are intended to be run on the single
127  task that is designated as the coordinator. Note that `cluster_resolver`,
128  `variable_partitioner`, and `dataset_fn` arguments are explained in the
129  following "Cluster setup", "Variable partitioning", and "Dataset preparation"
130  sections.
131
132  With a CTL,
133
134  ```python
135  # Prepare a strategy to use with the cluster and variable partitioning info.
136  strategy = tf.distribute.experimental.ParameterServerStrategy(
137      cluster_resolver=...,
138      variable_partitioner=...)
139  coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
140      strategy=strategy)
141
142  # Prepare a distributed dataset that will place datasets on the workers.
143  distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn=...)
144
145  with strategy.scope():
146    model = ...
147    optimizer, metrics = ...  # Keras optimizer/metrics are great choices
148    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
149    checkpoint_manager = tf.train.CheckpointManager(
150        checkpoint, checkpoint_dir, max_to_keep=2)
151    # `load_checkpoint` (sketched below) infers initial epoch from `optimizer.iterations`.
152    initial_epoch = load_checkpoint(checkpoint_manager) or 0
153
154  @tf.function
155  def worker_fn(iterator):
156
157    def replica_fn(inputs):
158      batch_data, labels = inputs
159      # Compute gradients, apply them, update metrics, etc.
160
161    strategy.run(replica_fn, args=(next(iterator),))
162
163  for epoch in range(initial_epoch, num_epoch):
164    distributed_iterator = iter(distributed_dataset)  # Reset iterator state.
165    for step in range(steps_per_epoch):
166
167      # Asynchronously schedule the `worker_fn` to be executed on an arbitrary
168      # worker. This call returns immediately.
169      coordinator.schedule(worker_fn, args=(distributed_iterator,))
170
171    # `join` blocks until all scheduled `worker_fn`s finish execution. Once it
172    # returns, we can read the metrics and save checkpoints as needed.
173    coordinator.join()
174    logging.info('Metric result: %r', metrics.result())
175    metrics.reset_states()
176    checkpoint_manager.save()
177  ```
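
  A minimal sketch of one possible `load_checkpoint` helper assumed by the
  example above (it is not a TensorFlow API); it relies on `steps_per_epoch`
  from the training loop and on the `optimizer` attached to the checkpoint:

  ```python
  def load_checkpoint(checkpoint_manager):
    # Restore the latest checkpoint, if any, and return the epoch to resume at.
    latest = checkpoint_manager.latest_checkpoint
    if latest is None:
      return None
    checkpoint_manager.checkpoint.restore(latest)
    # `optimizer.iterations` counts every step applied across previous runs.
    optimizer = checkpoint_manager.checkpoint.optimizer
    return int(optimizer.iterations.numpy()) // steps_per_epoch
  ```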
178
179  With `Model.fit`,
180
181  ```python
182  # Prepare a strategy to use with the cluster and variable partitioning info.
183  strategy = tf.distribute.experimental.ParameterServerStrategy(
184      cluster_resolver=...,
185      variable_partitioner=...)
186
187  # A dataset function takes an `input_context` and returns a `Dataset`.
188  def dataset_fn(input_context):
189    dataset = tf.data.Dataset.from_tensors(...)
190    return dataset.repeat().shard(...).batch(...).prefetch(...)
191
192  # With `Model.fit`, a `DatasetCreator` needs to be used.
193  input = tf.keras.utils.experimental.DatasetCreator(dataset_fn=...)
194
195  with strategy.scope():
196    model = ...  # Make sure the `Model` is created within scope.
197  model.compile(optimizer="rmsprop", loss="mse", steps_per_execution=..., ...)
198
199  # Optional callbacks to checkpoint the model, back up the progress, etc.
200  callbacks = [tf.keras.callbacks.ModelCheckpoint(...), ...]
201
202  # `steps_per_epoch` is required with `ParameterServerStrategy`.
203  model.fit(input, epochs=..., steps_per_epoch=..., callbacks=callbacks)
204  ```
205
206  __Example code for worker and parameter servers__
207
208  In addition to the coordinator, there should be tasks designated as
209  "worker" or "ps". They should run the following code to start a TensorFlow
210  server and wait for the coordinator's requests:
211
212  ```python
213  # Provide a `tf.distribute.cluster_resolver.ClusterResolver` that serves
214  # the cluster information. See below "Cluster setup" section.
215  cluster_resolver = ...
216
217  server = tf.distribute.Server(
218      cluster_resolver.cluster_spec(),
219      job_name=cluster_resolver.task_type,
220      task_index=cluster_resolver.task_id,
221      protocol="grpc")
222
223  # Block the process that runs the server from exiting.
224  server.join()
225  ```
226
227  __Cluster setup__
228
229  In order for the tasks in the cluster to know other tasks' addresses,
230  a `tf.distribute.cluster_resolver.ClusterResolver` must be used
231  by the coordinator, workers, and parameter servers. The
232  `tf.distribute.cluster_resolver.ClusterResolver` is responsible for providing
233  the cluster information, as well as the task type and id of the current task.
234  See `tf.distribute.cluster_resolver.ClusterResolver` for more information.
235
236  If the `TF_CONFIG` environment variable is set, a
237  `tf.distribute.cluster_resolver.TFConfigClusterResolver` should be used as
238  well.
239
240  Since there are assumptions in
241  `tf.distribute.experimental.ParameterServerStrategy` around the naming of the
242  task types, "chief", "ps", and "worker" should be used in the
243  `tf.distribute.cluster_resolver.ClusterResolver` to refer to the coordinator,
244  parameter servers, and workers, respectively.
245
246  The following example demonstrates setting `TF_CONFIG` for the task designated
247  as a parameter server (task type "ps") and index 1 (the second task), in a
248  cluster with 1 chief, 2 parameter servers, and 3 workers. Note that it needs
249  to be set before the use of
250  `tf.distribute.cluster_resolver.TFConfigClusterResolver`.
251
252  Example code for cluster setup:
253  ```python
254  os.environ['TF_CONFIG'] = '''
255  {
256    "cluster": {
257      "chief": ["chief.example.com:2222"],
258      "ps": ["ps0.example.com:2222", "ps1.example.com:2222"],
259      "worker": ["worker0.example.com:2222", "worker1.example.com:2222",
260                 "worker2.example.com:2222"]
261    },
262    "task": {
263      "type": "ps",
264      "index": 1
265    }
266  }
267  '''
268  ```
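
  Each task can then construct its resolver from `TF_CONFIG`, for example
  (a minimal sketch):

  ```python
  cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
  # The resolver exposes the parsed cluster plus this task's type and index.
  assert cluster_resolver.task_type == "ps"
  assert cluster_resolver.task_id == 1
  ```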
269
270  If you prefer to run the same binary for all tasks, you will need to let the
271  binary branch into different roles at the beginning of the program:
272  ```python
273  # If coordinator, create a strategy and start the training program.
274  if cluster_resolver.task_type == 'chief':
275    strategy = tf.distribute.experimental.ParameterServerStrategy(
276        cluster_resolver)
277    ...
278
279  # If worker/ps, create a server
280  elif cluster_resolver.task_type in ("worker", "ps"):
281    server = tf.distribute.Server(...)
282    ...
283  ```
284  Alternatively, you can also start the TensorFlow servers in advance and
285  connect to them later. The coordinator can be in the same cluster or on any
286  machine that has connectivity to workers and parameter servers. This is
287  covered in our guide and tutorial.
288
289  __Variable creation with `strategy.scope()`__
290
291  `tf.distribute.experimental.ParameterServerStrategy` follows the
292  `tf.distribute` API contract where variable creation is expected to be inside
293  the context manager returned by `strategy.scope()`, in order to be correctly
294  placed on parameter servers in a round-robin manner:
295
296  ```python
297  # In this example, we assume the cluster has 3 parameter servers ("ps").
298  strategy = tf.distribute.experimental.ParameterServerStrategy(
299      cluster_resolver=...)
300  coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
301      strategy=strategy)
302
303  # Variables should be created inside scope to be placed on parameter servers.
304  # If created outside scope such as `v1` here, it would be placed on the
305  # coordinator.
306  v1 = tf.Variable(initial_value=0.0)
307
308  with strategy.scope():
309    v2 = tf.Variable(initial_value=1.0)
310    v3 = tf.Variable(initial_value=2.0)
311    v4 = tf.Variable(initial_value=3.0)
312    v5 = tf.Variable(initial_value=4.0)
313
314  # v2 through v5 are created in scope and are distributed on parameter servers.
315  # Default placement is round-robin but the order should not be relied on.
316  assert v2.device == "/job:ps/replica:0/task:0/device:CPU:0"
317  assert v3.device == "/job:ps/replica:0/task:1/device:CPU:0"
318  assert v4.device == "/job:ps/replica:0/task:2/device:CPU:0"
319  assert v5.device == "/job:ps/replica:0/task:0/device:CPU:0"
320  ```
321
322  See `distribute.Strategy.scope` for more information.
323
324  __Variable partitioning__
325
326  Having dedicated servers to store variables means being able to divide up, or
327  "shard", the variables across the ps. Partitioning a large variable among ps
328  is a commonly used technique to boost training throughput and mitigate memory
329  constraints. It enables parallel computations and updates on different shards
330  of a variable, and often yields better load balancing across parameter
331  servers. Without sharding, models with large variables (e.g., embeddings) that
332  can't fit into one machine's memory would be unable to train at all.
333
334  With `tf.distribute.experimental.ParameterServerStrategy`, if a
335  `variable_partitioner` is provided to `__init__` and certain conditions are
336  satisfied, the resulting variables created in scope are sharded across the
337  parameter servers, in a round-robin fashion. The variable reference returned
338  from `tf.Variable` becomes a type that serves as the container of the sharded
339  variables. One can access the `variables` attribute of this container for the
340  actual variable components. If building a model with `tf.Module` or Keras,
341  the variable components are collected in the `variables`-like attributes.
342
343  It is recommended to use size-based partitioners like
344  `tf.distribute.experimental.partitioners.MinSizePartitioner` to avoid
345  partitioning small variables, which could have a negative impact on model
346  training speed.
347
348  ```python
349  # Partition the embedding layer into 2 shards.
350  variable_partitioner = (
351    tf.distribute.experimental.partitioners.MinSizePartitioner(
352      min_shard_bytes=(256 << 10),
353      max_shards = 2))
354  strategy = tf.distribute.experimental.ParameterServerStrategy(
355    cluster_resolver=...,
356    variable_partitioner = variable_partitioner)
357  with strategy.scope():
358    embedding = tf.keras.layers.Embedding(input_dim=1024, output_dim=1024)
359  assert len(embedding.variables) == 2
360  assert isinstance(embedding.variables[0], tf.Variable)
361  assert isinstance(embedding.variables[1], tf.Variable)
362  assert embedding.variables[0].shape == (512, 1024)
363  assert embedding.variables[1].shape == (512, 1024)
364  ```
365
366  The sharded variable container can be converted to a `Tensor` via
367  `tf.convert_to_tensor`. This means the container can be directly used in most
368  Python Ops where such `Tensor` conversion automatically happens. For example,
369  multiplying an input tensor by a sharded variable `w`, as in `x * w`, would
370  implicitly apply the said tensor conversion. Note that such conversion can be
371  expensive, as the variable components need to be transferred from multiple
372  parameter servers to where the value is used.
373
374  `tf.nn.embedding_lookup` on the other hand doesn't apply the tensor
375  conversion, and performs parallel lookups on the variable components instead.
376  This is crucial to scale up embedding lookups when the embedding table
377  variable is large.
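
  As a rough sketch (hypothetical shapes; assumes the strategy and cluster set
  up as above), the two access patterns look like:

  ```python
  with strategy.scope():
    w = tf.Variable(tf.random.normal([1024, 64]))  # may be sharded across ps

  # Dense access: implicit `tf.convert_to_tensor` gathers all shards.
  dense_out = tf.matmul(tf.ones([2, 1024]), tf.convert_to_tensor(w))

  # Sparse access: only the shards holding the requested rows are read.
  rows = tf.nn.embedding_lookup(w, [3, 900])
  ```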
378
379  When a partitioned variable is saved to a `SavedModel`, it will be saved as if
380  it were one single variable. This improves serving efficiency by eliminating
381  a number of Ops that handle the partition aspects.
382
383  Known limitations of variable partitioning:
384
385  * The number of partitions must not change across checkpoint saving/loading.
386
387  * After saving partitioned variables to a SavedModel, the SavedModel can't be
388    loaded via `tf.saved_model.load`.
389
390  * Partitioned variables don't work directly with `tf.GradientTape`; use
391    the `variables` attribute to get the actual variable components and use
392    them in gradient APIs instead, as sketched below.
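
  For example, a minimal sketch (hypothetical shapes) of computing gradients
  through a sharded variable via its components:

  ```python
  with strategy.scope():
    w = tf.Variable(tf.random.normal([1024, 64]))  # assumed to be sharded

  with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.convert_to_tensor(w))
  # Differentiate with respect to the components, not the container. If `w`
  # was not actually partitioned, pass `[w]` instead.
  grads = tape.gradient(loss, w.variables)
  ```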
393
394  __Dataset preparation__
395
396  With `tf.distribute.experimental.ParameterServerStrategy`, a dataset is
397  created on each of the workers to be used for training. This is done by
398  creating a `dataset_fn` that takes no argument and returns a
399  `tf.data.Dataset`, and passing the `dataset_fn` into
400  `tf.distribute.experimental.coordinator.
401  ClusterCoordinator.create_per_worker_dataset`. We recommend that the dataset
402  be shuffled and repeated so that the examples run through the training as
403  evenly as possible.
404
405  ```python
406  def dataset_fn():
407    filenames = ...
408    dataset = tf.data.Dataset.from_tensor_slices(filenames)
409
410    # The dataset is recommended to be shuffled and repeated.
411    return dataset.shuffle(buffer_size=...).repeat().batch(batch_size=...)
412
413  coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
414      strategy=...)
415  distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
416  ```
417
418  __Limitations__
419
420  * `tf.distribute.experimental.ParameterServerStrategy` in TF2 is experimental,
421  and the API is subject to further changes.
422
423  * When using `Model.fit`, `tf.distribute.experimental.ParameterServerStrategy`
424  must be used with a `tf.keras.utils.experimental.DatasetCreator`, and
425  `steps_per_epoch` must be specified.
426  """
427
428  # pyformat: disable
429  def __init__(self, cluster_resolver, variable_partitioner=None):
430    """Initializes the TF2 parameter server strategy.
431
432    This initializes the `tf.distribute.experimental.ParameterServerStrategy`
433    object to be ready for use with
434    `tf.distribute.experimental.coordinator.ClusterCoordinator`.
435
436    Args:
437      cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`
438        object.
439      variable_partitioner:
440        a `tf.distribute.experimental.partitioners.Partitioner` that specifies
441        how to partition variables. If `None`, variables will not be
442        partitioned.
443
444        * Predefined partitioners in `tf.distribute.experimental.partitioners`
445        can be used for this argument. A commonly used partitioner is
446        `MinSizePartitioner(min_shard_bytes = 256 << 10, max_shards = num_ps)`,
447        which allocates at least 256K per shard, and each ps gets at most one
448        shard.
449
450        * `variable_partitioner` will be called for each variable created under
451        strategy `scope` to instruct how the variable should be partitioned.
452        Variables that have only one partition along the partitioning axis
453        (i.e., no need for partition) will be created as a normal `tf.Variable`.
454
455        * Only the first / outermost axis partitioning is supported.
456
457        * Div partition strategy is used to partition variables. Assuming we
458        assign consecutive integer ids along the first axis of a variable, then
459        ids are assigned to shards in a contiguous manner, while attempting to
460        keep each shard size identical. If the ids cannot be divided evenly
461        among the shards, each of the first several shards will be assigned one
462        more id. For instance, a variable whose first dimension is 13 has 13
463        ids, and they are split across 5 shards as:
464        `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.
465
466        * Variables created under `strategy.extended.colocate_vars_with` will
467        not be partitioned.
468    """
469    # pyformat: enable
470    self._cluster_resolver = cluster_resolver
471
472    self._verify_args_and_config(cluster_resolver)
473    self._cluster_coordinator = None
474    logging.info(
475        "`tf.distribute.experimental.ParameterServerStrategy` is initialized "
476        "with cluster_spec: %s", cluster_resolver.cluster_spec())
477
478    # TODO(b/167894802): Make coordinator, worker, and ps names customizable.
479    self._connect_to_cluster(coordinator_name="chief")
480    self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver,
481                                                       variable_partitioner)
482    super(ParameterServerStrategyV2, self).__init__(self._extended)
483    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
484        "ParameterServerStrategy")
485    self._should_use_with_coordinator = True
486    # Used while constructing distributed iterators.
487    self._canonicalize_devices = False
488
489  def _connect_to_cluster(self, coordinator_name):
490    if coordinator_name in ["worker", "ps"]:
491      raise ValueError("coordinator name should not be 'worker' or 'ps'.")
492    cluster_spec = self._cluster_resolver.cluster_spec()
493    self._num_workers = len(cluster_spec.as_dict().get("worker", ()))
494    self._num_ps = len(cluster_spec.as_dict().get("ps", ()))
495
496    device_filters = server_lib.ClusterDeviceFilters()
497    # For any worker, only the devices on ps and coordinator nodes are visible
498    for i in range(self._num_workers):
499      device_filters.set_device_filters(
500          "worker", i, ["/job:ps", "/job:%s" % coordinator_name])
501    # Similarly for any ps, only the devices on workers and coordinator are
502    # visible
503    for i in range(self._num_ps):
504      device_filters.set_device_filters(
505          "ps", i, ["/job:worker", "/job:%s" % coordinator_name])
506
507    # Allow at most one outstanding RPC for each worker at any given time. This
508    # is to simplify worker failure handling in the runtime.
509    os.environ["TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"] = "False"
510
511    # Disable async executors to make context.async_wait a no-op. This avoids
512    # sending RPCs to remote workers since the executors used by PSStrategy
513    # are known to be always synchronous.
514    os.environ["TF_PS_DISABLE_ASYNC_EXECUTOR_GLOBALLY"] = "True"
515
516    logging.info("%s is now connecting to cluster with cluster_spec: %r",
517                 self.__class__.__name__, cluster_spec)
518    remote.connect_to_cluster(
519        cluster_spec,
520        job_name=coordinator_name,
521        protocol=self._cluster_resolver.rpc_layer,
522        cluster_device_filters=device_filters)
523
524    distribute_lib.distribution_strategy_replica_gauge.get_cell(
525        "ps_strategy_num_workers").set(self._num_workers)
526    distribute_lib.distribution_strategy_replica_gauge.get_cell(
527        "ps_strategy_num_ps").set(self._num_ps)
528
529  def _verify_args_and_config(self, cluster_resolver):
530    if not cluster_resolver.cluster_spec():
531      raise ValueError("Cluster spec must be non-empty in "
532                       "`tf.distribute.cluster_resolver.ClusterResolver`.")
533    cluster_spec = cluster_resolver.cluster_spec()
534
535    # The following checks if the task types are allowed (chief, ps, worker).
536    multi_worker_util._validate_cluster_spec(  # pylint: disable=protected-access
537        cluster_spec, cluster_resolver.task_type, cluster_resolver.task_id)
538
539    if multi_worker_util.task_count(cluster_spec, "ps") < 1:
540      raise ValueError("There must be at least one ps.")
541
542    if multi_worker_util.task_count(cluster_spec, "worker") < 1:
543      raise ValueError("There must be at least one worker.")
544
545
546class ParameterServerStrategyV2Extended(
547    parameter_server_strategy.ParameterServerStrategyExtended):
548  """Extended class for ParameterServerStrategyV2.
549
550  Please see `tf.distribute.StrategyExtended` doc for more information.
551  """
552
553  def __init__(self, container_strategy, cluster_resolver,
554               variable_partitioner):
555    """Initialization of ParameterServerStrategyV2Extended."""
556    super(ParameterServerStrategyV2Extended, self).__init__(container_strategy)
557    self._num_ps = len(cluster_resolver.cluster_spec().as_dict().get("ps", []))
558    self._num_workers = len(cluster_resolver.cluster_spec().as_dict().get(
559        "worker", []))
560    self._variable_count = 0
561
562    self._variable_partitioner = variable_partitioner
563    # The following two attrs are to verify that `ParameterServerStrategy`
564    # methods are properly used with a `ClusterCoordinator`.
565    self._used_with_coordinator = False
566    self._being_scheduled = False
567    self._set_num_gpus()
568    distribute_lib.distribution_strategy_replica_gauge.get_cell(
569        "num_gpus_per_worker").set(self._num_gpus_per_worker)
570
571    # Don't canonicalize the devices here since this code is executed on Chief,
572    # but we want the reduce evaluation to be done on each worker. Placer will
573    # automatically choose the right device based on current context.
574    # TODO(ishark): Use select_cross_device_ops instead.
575    self._cross_device_ops = cross_device_ops_lib.ReductionToOneDevice(
576        reduce_to_device="/device:CPU:0")
577    self._cross_device_ops._canonicalize_devices = False  # pylint: disable=protected-access
578    self._allow_run_without_coordinator = False
579    self._coordinator_creation_lock = threading.Lock()
580
581  def _set_num_gpus(self):
582    devices = config.list_logical_devices("GPU")
583    per_worker_gpus = {}
584    for d in devices:
585      d_spec = tf_device.DeviceSpec.from_string(d.name)
586      if d_spec.device_type == "GPU" and d_spec.job == "worker":
587        # TODO(b/167894802): update if worker name is customizable
588        job_spec = d_spec.replace(device_type=None, device_index=None)
589        per_worker_gpus[job_spec] = per_worker_gpus.get(job_spec, 0) + 1
590
591    num_gpus = 0
592    for _, count in per_worker_gpus.items():
593      if num_gpus > 0 and count != num_gpus:
594        raise ValueError("Mismatched number of GPUs per worker")
595      num_gpus = count
596
597    self._num_gpus_per_worker = num_gpus
598    logging.info(f"Number of GPUs on workers: {self._num_gpus_per_worker}")
599
600  @property
601  def _num_replicas_in_sync(self):
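    # Each worker contributes one replica per local GPU, or a single CPU
    # replica when no GPUs are present.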
602    return self._num_gpus_per_worker or 1
603
604  def _create_var_creator(self, next_creator, **kwargs):
605    aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
606
607    def var_creator(**kwargs):
608      """Create an AggregatingVariable."""
609      # Create and wrap the variable.
610      v = next_creator(**kwargs)
611      wrapped_v = ps_values.CachingVariable(v)
612      wrapped = ps_values.AggregatingVariable(self._container_strategy(),
613                                              wrapped_v, aggregation)
614      return wrapped
615
616    if self._num_replicas_in_sync > 1:
617      if aggregation not in (vs.VariableAggregation.NONE,
618                             vs.VariableAggregation.SUM,
619                             vs.VariableAggregation.MEAN,
620                             vs.VariableAggregation.ONLY_FIRST_REPLICA):
621        raise ValueError(f"Invalid variable aggregation mode: {aggregation} "
622                         f"for variable: {kwargs['name']}")
623      return var_creator
624    else:
625
626      def variable_creator_single_replica(**kwargs):
627        v = next_creator(**kwargs)
628        return ps_values.CachingVariable(v)
629
630      return variable_creator_single_replica
631
632  def _create_variable(self, next_creator, **kwargs):
633    """Implements StrategyExtendedV2._create_variable.
634
635    Creates a `Variable` or a `ShardedVariable`. A `ShardedVariable` will be
636    created if all of the following criteria are satisfied:
637      1. `self._variable_partitioner` results in more than one partition on the
638         first axis.
639      2. The variable's rank is greater than 0.
640      3. The variable is not colocated with another variable.
641    Otherwise a `Variable` will be created.
642
643    Args:
644      next_creator: See `variable_scope.variable_creator_scope`; the next
645        creator in the chain.
646      **kwargs: Passed through to the next creator.
647
648    Returns:
649      A `Variable` or `ShardedVariable`.
650    """
651
652    var_creator = self._create_var_creator(next_creator, **kwargs)
653    if "colocate_with" in kwargs:  # Never partition colocated_with variables.
654      colocate_with = kwargs["colocate_with"]
655      # Clear the variable scope to avoid possible conflicts between device
656      # scope and colocation scope.
657      with ops.device(None):
658        with ops.colocate_with(colocate_with):
659          var = var_creator(**kwargs)
660          logging.debug(
661              "Creating variable (name:%s, shape:%r) that colocates with %s",
662              var.name, var.shape, kwargs["colocate_with"].name)
663          return var
664
665    if self._variable_partitioner is None:
666      return self._create_variable_round_robin(var_creator, **kwargs)
667
668    name = kwargs.get("name", None)
669    dtype = kwargs.get("dtype", None)
670    shape = kwargs.get("shape", None)
671    initial_value = kwargs.get("initial_value", None)
672    if initial_value is None:
673      # If we are loading, next_creator will return an UninitializedVariable
674      v = next_creator(**kwargs)
675      if not isinstance(v, resource_variable_ops.UninitializedVariable):
676        raise ValueError(
677            "It looks like you are using `ParameterServerStrategy` with a "
678            "`variable_partitioner`, and trying to create a variable without "
679            "specifying `initial_value`. This is not allowed. Please specify the "
680            "`initial_value`.")
681      elif shape is None or dtype is None:
682        raise ValueError(
683            "It looks like you are trying to load a `SavedModel` using "
684            "`tf.saved_model.load` within a `ParameterServerStrategy` scope, "
685            "but the `SavedModel` is missing shape or dtype information.")
686      else:
687        def initializer(shape, dtype, **kwargs):
688          if "partition_shape" in kwargs:
689            shape = kwargs["partition_shape"]
690          return array_ops.zeros(shape, dtype)
691        initial_value = functools.partial(initializer, shape=shape, dtype=dtype)
692
693    # Two cases where initial_value can be a callable:
694    #   1. initial_value is passed as a callable, e.g., an `initializer` class.
695    #   2. When restoring from a checkpoint, initial_value is a
696    #     "CheckpointInitialValueCallable".
697    init_from_fn = callable(initial_value)
698
699    if init_from_fn and (shape is None or dtype is None):
700      init_from_fn = False
701      initial_value = initial_value()
702    if not init_from_fn:
703      # The initial_value is created on the coordinator, so it will need to be
704      # sent to the ps for variable initialization, which can be inefficient and
705      # can potentially hit the 2GB limit on protobuf serialization.
706      initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
707      dtype = initial_value.dtype
708      shape = initial_value.shape
709    else:
710      shape = tensor_shape.as_shape(shape)
711
712    if shape.rank == 0:  # Skip partitioning rank-0 variable.
713      return self._create_variable_round_robin(var_creator, **kwargs)
714
715    num_partitions = self._variable_partitioner(shape=shape, dtype=dtype)
716    if not num_partitions or num_partitions[0] == 0 or any(
717        v != 1 for v in num_partitions[1:]):
718      raise ValueError(
719          "variable_partitioner must return a list/tuple whose first element "
720          "is non-zero and all other elements are 1, got: %r" % num_partitions)
721
722    if num_partitions[0] == 1:  # no partition
723      return self._create_variable_round_robin(var_creator, **kwargs)
724
725    # Use "div" partition strategy to partition the variable.
726    num_partitions = min(num_partitions[0], shape[0])
727    base = shape[0] // num_partitions
728    extra = shape[0] % num_partitions
729    # An example: num_partitions=4, shape[0]=10, partitions: [3, 3, 2, 2]
730    # offsets: [0, 3, 6, 8, 10]
731    offsets = []
732    for i in range(num_partitions):
733      if i == 0:
734        offsets.append(0)
735      else:
736        prev_shard_size = base + (1 if i - 1 < extra else 0)
737        offsets.append(offsets[i - 1] + prev_shard_size)
738    offsets.append(shape[0])
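    # `offsets` has num_partitions + 1 entries; shard i spans
    # [offsets[i], offsets[i + 1]) along the first axis.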
739
740    def init_shard_fn(shard_index):
741      if not init_from_fn:
742        logging.log_if(
743            logging.WARN, _INEFFICIENT_INIT_WARNING % name, shard_index == 0 and
744            shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
745        return initial_value[offsets[shard_index]:offsets[shard_index + 1]]
746      partition_shape = (offsets[shard_index + 1] -
747                         offsets[shard_index],) + shape[1:]
748      partition_offset = (offsets[shard_index],) + (0,) * len(shape[1:])
749      arg_spec = tf_inspect.getfullargspec(initial_value)
750      if ("shard_info" not in arg_spec.args and
751          "shard_info" not in arg_spec.kwonlyargs):
752        try:
753          value = initial_value(
754              partition_shape=partition_shape,
755              partition_offset=partition_offset)
756        except (TypeError, ValueError):
757          # TypeError: Initializer doesn't accept kwargs
758          # ValueError: Initializer doesn't accept partition kwargs
759          # In both cases we go ahead creating the full value and then slice.
760          value = initial_value()
761
762        if value.shape == partition_shape:
763          # Initializer supports partition: value is the partition value.
764          return value
765        else:
766          # Initializer doesn't support partition: value is the full value
767          # and needs to be sliced to get the partition value.
768          logging.log_if(
769              logging.WARN, _INEFFICIENT_INIT_WARNING % name,
770              shard_index == 0 and
771              shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
772          return value[offsets[shard_index]:offsets[shard_index + 1]]
773      else:
774        # For compatibility with `CheckpointInitialValueCallable`.
775        return initial_value(
776            shard_info=trackable.ShardInfo(
777                shape=tensor_shape.as_shape(partition_shape),
778                offset=partition_offset))
779
780    var_list = []
781    for i in range(num_partitions):
782      kwargs["shape"] = (offsets[i + 1] - offsets[i],) + shape[1:]
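      # `init_shard_fn(i)` runs when the variable is constructed below, within
      # this same iteration, so capturing `i` in the lambda is safe here.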
783      kwargs["initial_value"] = lambda: init_shard_fn(i)
784      if name is not None:
785        kwargs["name"] = "{}/part_{}".format(name, i)
786      var_list.append(self._create_variable_round_robin(var_creator, **kwargs))
787
788    result = sharded_variable.ShardedVariable(var_list)
789    return result
790
791  def _create_variable_round_robin(self, next_creator, **kwargs):
792    # Clear the colocation scope to avoid possible conflicts between device
793    # scope and colocation scope.
794    with ops.colocate_with(None, ignore_existing=True):
795      # Explicitly set the CPU:0 device for ps in case variable creation is
796      # called inside replica_fn and the worker is under a GPU:0 device scope.
797      with ops.device("/job:ps/task:%d/device:CPU:0" %
798                      (self._variable_count % self._num_ps)):
799        var = next_creator(**kwargs)
800        logging.debug(
801            "Creating variable (name:%s, shape:%r) on "
802            "/job:ps/task:%d/device:CPU:0", var.name, var.shape,
803            (self._variable_count % self._num_ps))
804        self._variable_count += 1
805        return var
806
807  def _resource_creator_scope(self):
808
809    with self._coordinator_creation_lock:
810      if not self._container_strategy()._cluster_coordinator:  # pylint: disable=protected-access
811        cluster_coordinator.ClusterCoordinator(
812            strategy=self._container_strategy())
813
814    # TODO(wxinyi): We should warn the user of the inefficiency of creating
815    # `StaticHashTable` inside a `@tf.function`-wrapped `dataset_fn` to be
816    # distributed with `distribute_datasets_from_function` and
817    # `create_per_worker_dataset`. This is because the `dataset_fn` does not
818    # use the same `default_graph` as `scope` to which the
819    # `resource_creator_stack` belongs. Thus, `StaticHashTable` creation inside
820    # `dataset_fn` is not intercepted. And since its resource creation under a
821    # `tf.function` is lifted out, all workers will share the same resource on
822    # the coordinator which incurs worker-coordinator communication overhead.
823
824    def lookup_creator(next_creator, *args, **kwargs):
825      if load_context.in_load_context():
826        return (ps_values.RestoredDistributedTable(
827            self._container_strategy(), lambda: next_creator(*args, **kwargs)))  # pylint: disable=protected-access
828      else:
829        return ps_values.DistributedTable(self._container_strategy(),
830                                          lambda: next_creator(*args, **kwargs))  # pylint: disable=protected-access
831
832    def restored_lookup_creator(next_creator, *args, **kwargs):
833      return (ps_values.RestoredDistributedTable(
834          self._container_strategy(), lambda: next_creator(*args, **kwargs)))  # pylint: disable=protected-access
835
836    return [
837        ops.resource_creator_scope("StaticHashTable", lookup_creator),
838        ops.resource_creator_scope("RestoredStaticHashTable",
839                                   restored_lookup_creator)
840    ]
841
842  def _assert_used_with_cluster_coordinator(self):
843    if (not self._used_with_coordinator and
844        not self._allow_run_without_coordinator):
845      raise NotImplementedError(
846          "`tf.distribute.experimental.ParameterServerStrategy` must be used "
847          "with `tf.distribute.experimental.coordinator.ClusterCoordinator` in "
848          "a custom training loop. If you are using `Model.fit`, please supply "
849          "a dataset function directly to a "
850          "`tf.keras.utils.experimental.DatasetCreator` instead.")
851
852  def _assert_being_scheduled_by_cluster_coordinator(self):
853    if not self._being_scheduled and not self._allow_run_without_coordinator:
854      logging.warning(
855          "A `tf.distribute.experimental.ParameterServerStrategy` method is "
856          "invoked without using `ClusterCoordinator.schedule`. If you are not "
857          "tracing a tf.function, this method is possibly executed on the "
858          "coordinator, which can be slow. To properly dispatch functions to "
859          "run on workers, methods like `run` or `reduce` should be used "
860          "within a function passed to `tf.distribute.experimental.coordinator."
861          "ClusterCoordinator.schedule`.")
862
863  # `options` is not used right now, but we may want to support it while
864  # creating InputWorkers in the future, similar to MirroredStrategy.
865  def _input_workers_with_options(self, options=None):
866    input_workers_devices = (("/device:CPU:0", self.worker_devices),)
867    return input_lib.InputWorkers(
868        input_workers_devices, canonicalize_devices=False)
869
870  def _experimental_distribute_dataset(self, dataset, options):
871    input_workers_devices = self._input_workers_with_options()
872
873    # If this DistributedDataset is created outside ClusterCoordinator, i.e.,
874    # outside a tf.function, we don't build its underlying datasets
875    # until it is passed to ClusterCoordinator.create_per_worker_dataset.
876    return input_util.get_distributed_dataset(
877        dataset,
878        input_workers_devices,
879        self._container_strategy(),
880        num_replicas_in_sync=self._num_replicas_in_sync,
881        options=options,
882        build=ops.inside_function())  # will be built by ClusterCoordinator
883
884  def _distribute_datasets_from_function(self, dataset_fn, options):
885    # There is no synchronization across workers, and thus the number of
886    # input pipelines in sync is only 1 per worker.
887    input_pipeline_id_in_sync = 0
888    num_input_pipelines_in_sync = 1
889
890    input_context = distribute_lib.InputContext(
891        num_input_pipelines=num_input_pipelines_in_sync,
892        input_pipeline_id=input_pipeline_id_in_sync,
893        num_replicas_in_sync=self._num_replicas_in_sync)
894
895    # If this DistributedDatasetFromFunction is created outside
896    # ClusterCoordinator, i,e, outside a tf.function, we don't build its
897    # underlying datasets immediately until it is passed to
898    # ClusterCoordinator.create_per_worker_dataset.
899    return input_util.get_distributed_datasets_from_function(
900        dataset_fn,
901        self._input_workers_with_options(options), [input_context],
902        self._container_strategy(),
903        options=options,
904        build=ops.inside_function())  # will be built by ClusterCoordinator
905
906  @property
907  def worker_devices(self):
908    num_gpus = self._num_gpus_per_worker
909    if num_gpus > 0:
910      compute_devices = tuple("/device:GPU:%d" % (i,) for i in range(num_gpus))
911    else:
912      compute_devices = ("/device:CPU:0",)
913    return compute_devices
914
915  def _call_for_each_replica(self, fn, args, kwargs):
916    self._assert_being_scheduled_by_cluster_coordinator()
917
918    return mirrored_run.call_for_each_replica(self._container_strategy(), fn,
919                                              args, kwargs)
920
921  def _reduce(self, reduce_op, value):
922    self._assert_being_scheduled_by_cluster_coordinator()
923    dst = device_util.current() or self._default_device or "/device:CPU:0"
924    destinations = device_util.canonicalize_without_job_and_task(dst)
925    result = self._local_results(
926        self.reduce_to(reduce_op, value, destinations))[0]
927    return result
928
929  def _reduce_to(self, reduce_op, value, destinations, options):
930    self._assert_being_scheduled_by_cluster_coordinator()
931
932    def get_values(x):
933      if isinstance(x, values.DistributedValues):
934        return self._cross_device_ops.reduce(
935            reduce_op, x, destinations=destinations)  # pylint: disable=protected-access
936      return x
937
938    return nest.map_structure(get_values, value)
939
940
941# The warning that will be logged if the way we initialize sharded variables
942# is memory-inefficient.
943_INEFFICIENT_INIT_WARNING = (
944    "Large variable %s is partitioned but not initialized in a "
945    "memory-efficient way. On each shard, the full value is first created "
946    "and then sliced into smaller values. To reduce the memory "
947    "footprint, explicitly specify `dtype` and `shape` when creating "
948    "variables, and use `tf.initializers` to initialize the variable. "
949    "Note that some initializers (e.g., orthogonal) don't support "
950    "memory-efficient initialization and there is not much you can do here.")
951
952_LARGE_VARIABLE_NUM_ELEMENTS = 1e9
953