# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class implementing a single machine parameter server strategy."""

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.util.tf_export import tf_export


@tf_export('distribute.experimental.CentralStorageStrategy', v1=[])
class CentralStorageStrategy(distribute_lib.Strategy):
25  """A one-machine strategy that puts all variables on a single device.
26
27  Variables are assigned to local CPU or the only GPU. If there is more
28  than one GPU, compute operations (other than variable update operations)
29  will be replicated across all GPUs.
30
31  For Example:
32  ```
33  strategy = tf.distribute.experimental.CentralStorageStrategy()
34  # Create a dataset
35  ds = tf.data.Dataset.range(5).batch(2)
36  # Distribute that dataset
37  dist_dataset = strategy.experimental_distribute_dataset(ds)
38
39  with strategy.scope():
40    @tf.function
41    def train_step(val):
42      return val + 1
43
44    # Iterate over the distributed dataset
45    for x in dist_dataset:
46      # process dataset elements
47      strategy.run(train_step, args=(x,))
48  ```
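
  A variable created inside `strategy.scope()` is placed on the single
  parameter device, while `strategy.run` dispatches the computation to every
  compute replica. A minimal sketch (the exact device string printed depends
  on the machine this runs on):

  ```
  strategy = tf.distribute.experimental.CentralStorageStrategy()
  with strategy.scope():
    v = tf.Variable(1.0)
  print(v.device)  # e.g. a local CPU device on a host with zero or several GPUs
  ```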
49  """
50
51  def __init__(self, compute_devices=None, parameter_device=None):
52    extended = parameter_server_strategy.ParameterServerStrategyExtended(
53        self,
54        compute_devices=compute_devices,
55        parameter_device=parameter_device)
56    """Initializes the strategy with optional device strings.

    Args:
      compute_devices: an optional list of device strings to replicate the
        model on. If this is not provided, all local GPUs will be used; if
        there is no GPU, the local CPU will be used.
      parameter_device: an optional device string specifying the device on
        which to place variables. Defaults to the CPU, or to the GPU if there
        is exactly one GPU.
    """
    extended = parameter_server_strategy.ParameterServerStrategyExtended(
        self,
        compute_devices=compute_devices,
        parameter_device=parameter_device)
    super(CentralStorageStrategy, self).__init__(extended)
    distribute_lib.distribution_strategy_gauge.get_cell('V2').set(
        'CentralStorageStrategy')

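  # Internal helper: builds a strategy whose compute devices are the local
  # devices implied by `num_gpus`; the parameter device is left at its default.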
  @classmethod
  def _from_num_gpus(cls, num_gpus):
    return cls(device_util.local_devices_from_num_gpus(num_gpus))

  def experimental_distribute_dataset(self, dataset, options=None):  # pylint: disable=useless-super-delegation
    """Distributes a tf.data.Dataset instance provided via dataset.

    The returned dataset is a wrapped strategy dataset which creates a
    multidevice iterator under the hood. It prefetches the input data to the
    specified devices on the worker. The returned distributed dataset can be
    iterated over similarly to regular datasets.

    NOTE: Currently, the user cannot add any more transformations to a
    distributed dataset.

    For Example:
    ```
    # With 1 CPU and 1 GPU.
    strategy = tf.distribute.experimental.CentralStorageStrategy()
    dataset = tf.data.Dataset.range(10).batch(2)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)
    for x in dist_dataset:
      print(x)  # Prints PerReplica values [0, 1], [2, 3],...

    ```
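
    You can also pass a `tf.distribute.InputOptions` instance to tune how the
    input is distributed. A minimal sketch continuing the example above (note
    that `InputReplicationMode.PER_REPLICA` is not supported by this method,
    so the default `PER_WORKER` mode is kept):

    ```
    options = tf.distribute.InputOptions(
        experimental_replication_mode=tf.distribute.InputReplicationMode.PER_WORKER)
    dist_dataset = strategy.experimental_distribute_dataset(dataset, options)
    ```
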
    Args:
      dataset: `tf.data.Dataset` to be prefetched to device.
      options: `tf.distribute.InputOptions` used to control how this dataset
        is distributed.

    Returns:
      A "distributed `Dataset`" that the caller can iterate over.
    """
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          'InputReplicationMode.PER_REPLICA '
          'is only supported in '
          '`experimental_distribute_datasets_from_function`.'
      )
    return super(CentralStorageStrategy, self).experimental_distribute_dataset(
        dataset, options)

  def experimental_local_results(self, value):  # pylint: disable=useless-super-delegation
    """Returns the list of all local per-replica values contained in `value`.

    In `CentralStorageStrategy` there is a single worker so the value returned
    will be all the values on that worker.

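    For example, a small sketch; the number of entries in the returned tuple
    matches the number of compute devices the strategy was created with:

    ```
    strategy = tf.distribute.experimental.CentralStorageStrategy()

    @tf.function
    def step():
      return tf.constant(1.0)

    per_replica = strategy.run(step)
    local = strategy.experimental_local_results(per_replica)
    # `local` is a tuple with one tensor per compute replica, e.g. a
    # single-element tuple on a machine with one compute device.
    ```
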
    Args:
      value: A value returned by `run()`, `extended.call_for_each_replica()`,
        or a variable created in `scope`.

    Returns:
      A tuple of values contained in `value`. If `value` represents a single
      value, this returns `(value,)`.
    """
    return super(CentralStorageStrategy, self).experimental_local_results(value)

  def run(self, fn, args=(), kwargs=None, options=None):  # pylint: disable=useless-super-delegation
    """Run `fn` on each replica, with the given arguments.

    In `CentralStorageStrategy`, `fn` is called on each of the compute
    replicas, with the provided "per replica" arguments specific to that device.

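    For example, a minimal sketch; with a single compute device the result is
    a plain tensor, while with several devices it is a per-replica value:

    ```
    strategy = tf.distribute.experimental.CentralStorageStrategy()

    @tf.function
    def step(x):
      return x * 2.0

    result = strategy.run(step, args=(tf.constant(3.0),))
    ```
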
    Args:
      fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
      args: (Optional) Positional arguments to `fn`.
      kwargs: (Optional) Keyword arguments to `fn`.
      options: (Optional) An instance of `tf.distribute.RunOptions` specifying
        the options to run `fn`.

    Returns:
      Return value from running `fn`.
    """
    return super(CentralStorageStrategy, self).run(fn, args, kwargs, options)

  def reduce(self, reduce_op, value, axis):  # pylint: disable=useless-super-delegation
    """Reduce `value` across replicas.

    Given a per-replica value returned by `run`, say a
    per-example loss, the batch will be divided across all the replicas. This
    function allows you to aggregate across replicas and optionally also across
    batch elements.  For example, if you have a global batch size of 8 and 2
    replicas, values for examples `[0, 1, 2, 3]` will be on replica 0 and
    `[4, 5, 6, 7]` will be on replica 1. By default, `reduce` will just
    aggregate across replicas, returning `[0+4, 1+5, 2+6, 3+7]`. This is useful
    when each replica is computing a scalar or some other value that doesn't
    have a "batch" dimension (like a gradient). More often you will want to
    aggregate across the global batch, which you can get by specifying the batch
    dimension as the `axis`, typically `axis=0`. In this case it would return a
    scalar `0+1+2+3+4+5+6+7`.

    If there is a last partial batch, you will need to specify an axis so
    that the resulting shape is consistent across replicas. So if the last
    batch has size 6 and it is divided into [0, 1, 2, 3] and [4, 5], you
    would get a shape mismatch unless you specify `axis=0`. If you specify
    `tf.distribute.ReduceOp.MEAN`, using `axis=0` will use the correct
    denominator of 6. Contrast this with computing `reduce_mean` to get a
    scalar value on each replica and this function to average those means,
    which will weigh some values `1/8` and others `1/4`.

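    Concretely, for that last partial batch the per-replica values are
    `[0, 1, 2, 3]` on one replica and `[4, 5]` on the other, so the mean over
    the global batch works out as follows (a sketch of the arithmetic, with
    `per_replica_value` standing in for the value returned by `run`):

    ```
    strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_value, axis=0)
    # == (0 + 1 + 2 + 3 + 4 + 5) / 6 == 2.5
    ```
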
    For Example:
    ```
    strategy = tf.distribute.experimental.CentralStorageStrategy(
        compute_devices=['CPU:0', 'GPU:0'], parameter_device='CPU:0')
    ds = tf.data.Dataset.range(10)
    # Distribute that dataset
    dist_dataset = strategy.experimental_distribute_dataset(ds)

    with strategy.scope():
      @tf.function
      def train_step(val):
        # pass through
        return val

      # Iterate over the distributed dataset
      for x in dist_dataset:
        result = strategy.run(train_step, args=(x,))

    result = strategy.reduce(tf.distribute.ReduceOp.SUM, result,
                             axis=None).numpy()
    # result: array([ 4,  6,  8, 10])

    result = strategy.reduce(tf.distribute.ReduceOp.SUM, result, axis=0).numpy()
    # result: 28
    ```

    Args:
      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
        be combined.
      value: A "per replica" value, e.g. returned by `run` to
        be combined into a single tensor.
      axis: Specifies the dimension to reduce along within each
        replica's tensor. Should typically be set to the batch dimension, or
        `None` to only reduce across replicas (e.g. if the tensor has no batch
        dimension).

    Returns:
      A `Tensor`.
    """
    return super(CentralStorageStrategy, self).reduce(reduce_op, value, axis)


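# TF1 variant of `CentralStorageStrategy`; it is exported under
# `tf.compat.v1.distribute.experimental.CentralStorageStrategy` and reuses the
# docstrings of the V2 class above.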
@tf_export(v1=['distribute.experimental.CentralStorageStrategy'])  # pylint: disable=missing-docstring
class CentralStorageStrategyV1(distribute_lib.StrategyV1):

  __doc__ = CentralStorageStrategy.__doc__

  def __init__(self, compute_devices=None, parameter_device=None):
    super(CentralStorageStrategyV1, self).__init__(
        parameter_server_strategy.ParameterServerStrategyExtended(
            self,
            compute_devices=compute_devices,
            parameter_device=parameter_device))
    distribute_lib.distribution_strategy_gauge.get_cell('V1').set(
        'CentralStorageStrategy')

  __init__.__doc__ = CentralStorageStrategy.__init__.__doc__