# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A tf.distribute.Strategy for running on a single device."""

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import input_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# TODO(josh11b): Do we wrap values in types to generate errors if you are
# doing something that won't work with other DistributionStrategy
# implementations?


@tf_export("distribute.OneDeviceStrategy", v1=[])
class OneDeviceStrategy(distribute_lib.Strategy):
  """A distribution strategy for running on a single device.

  Using this strategy will place any variables created in its scope on the
  specified device. Input distributed through this strategy will be
  prefetched to the specified device. Moreover, any functions called via
  `strategy.run` will also be placed on the specified device.

  A typical use of this strategy is testing your code with the
  tf.distribute.Strategy API before switching to other strategies which
  actually distribute to multiple devices/machines.

  For example:
  ```
  strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")

  with strategy.scope():
    v = tf.Variable(1.0)
    print(v.device)  # /job:localhost/replica:0/task:0/device:GPU:0

  def step_fn(x):
    return x * 2

  result = 0
  for i in range(10):
    result += strategy.run(step_fn, args=(i,))
  print(result)  # 90
  ```
  """

  def __init__(self, device):
    """Creates a `OneDeviceStrategy`.

    Args:
      device: Device string identifier for the device on which the variables
        should be placed. See class docs for more details on how the device is
        used. Examples: "/cpu:0", "/gpu:0", "/device:CPU:0", "/device:GPU:0"
    """
    super(OneDeviceStrategy, self).__init__(OneDeviceExtended(self, device))
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "OneDeviceStrategy")

  def experimental_distribute_dataset(self, dataset, options=None):  # pylint: disable=useless-super-delegation
    """Distributes a tf.data.Dataset instance provided via `dataset`.

    In this case, there is only one device, so this is only a thin wrapper
    around the input dataset. It will, however, prefetch the input data to the
    specified device. The returned distributed dataset can be iterated over
    similarly to regular datasets.

    NOTE: Currently, the user cannot add any more transformations to a
    distributed dataset.

    Example:
    ```
    strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
    dataset = tf.data.Dataset.range(10).batch(2)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)
    for x in dist_dataset:
      print(x)  # [0, 1], [2, 3],...
    ```

    Args:
      dataset: `tf.data.Dataset` to be prefetched to device.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.

    Returns:
      A "distributed `Dataset`" that the caller can iterate over.
    """
    return super(OneDeviceStrategy, self).experimental_distribute_dataset(
        dataset, options)

  def distribute_datasets_from_function(
      self,
      dataset_fn,  # pylint: disable=useless-super-delegation
      options=None):
    """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.

    `dataset_fn` will be called once for each worker in the strategy. In this
    case, we only have one worker and one device, so `dataset_fn` is called
    once.

    The `dataset_fn` should take a `tf.distribute.InputContext` instance where
    information about batching and input replication can be accessed:

    ```
    def dataset_fn(input_context):
      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
      d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
      return d.shard(
          input_context.num_input_pipelines, input_context.input_pipeline_id)

    inputs = strategy.distribute_datasets_from_function(dataset_fn)

    for batch in inputs:
      replica_results = strategy.run(replica_fn, args=(batch,))
    ```

    IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a
    per-replica batch size, unlike `experimental_distribute_dataset`, which uses
    the global batch size. This may be computed using
    `input_context.get_per_replica_batch_size`.

    Args:
      dataset_fn: A function taking a `tf.distribute.InputContext` instance and
        returning a `tf.data.Dataset`.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.

    Returns:
      A "distributed `Dataset`", which the caller can iterate over like regular
      datasets.
    """
    return super(OneDeviceStrategy,
                 self).distribute_datasets_from_function(dataset_fn, options)

  def experimental_local_results(self, value):  # pylint: disable=useless-super-delegation
    """Returns the list of all local per-replica values contained in `value`.

    In `OneDeviceStrategy`, the `value` is always expected to be a single
    value, so the result is just the value in a tuple.
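
    For example, a minimal sketch (assuming a strategy placed on the CPU):
    ```
    strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
    result = strategy.run(lambda: tf.constant(1.0))
    strategy.experimental_local_results(result)
    # (<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,)
    ```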

    Args:
      value: A value returned by `experimental_run()`, `run()`,
        `extended.call_for_each_replica()`, or a variable created in `scope`.

    Returns:
      A tuple of values contained in `value`. If `value` represents a single
      value, this returns `(value,)`.
    """
    return super(OneDeviceStrategy, self).experimental_local_results(value)

  def run(self, fn, args=(), kwargs=None, options=None):  # pylint: disable=useless-super-delegation
    """Run `fn` on each replica, with the given arguments.

    In `OneDeviceStrategy`, `fn` is simply called within a device scope for the
    given device, with the provided arguments.
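
    For example, a minimal sketch (assuming a GPU is available):
    ```
    strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
    output = strategy.run(lambda x: x * 2, args=(tf.constant(3.0),))
    # `output` is a tensor with value 6.0, computed on the specified device.
    ```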

    Args:
      fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
      args: (Optional) Positional arguments to `fn`.
      kwargs: (Optional) Keyword arguments to `fn`.
      options: (Optional) An instance of `tf.distribute.RunOptions` specifying
        the options to run `fn`.

    Returns:
      Return value from running `fn`.
    """
    return super(OneDeviceStrategy, self).run(fn, args, kwargs, options)

  def reduce(self, reduce_op, value, axis):  # pylint: disable=useless-super-delegation
    """Reduce `value` across replicas.

    In `OneDeviceStrategy`, there is only one replica, so if `axis=None`, the
    value is simply returned. If `axis` is specified as something other than
    `None`, such as `axis=0`, the value is reduced along that axis and
    returned.

    Example:
    ```
    t = tf.range(10)

    result = strategy.reduce(tf.distribute.ReduceOp.SUM, t, axis=None).numpy()
    # result: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

    result = strategy.reduce(tf.distribute.ReduceOp.SUM, t, axis=0).numpy()
    # result: 45
    ```

    Args:
      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
        be combined.
      value: A "per replica" value, e.g. returned by `run`, to be combined into
        a single tensor.
      axis: Specifies the dimension to reduce along within each
        replica's tensor. Should typically be set to the batch dimension, or
        `None` to only reduce across replicas (e.g. if the tensor has no batch
        dimension).

    Returns:
      A `Tensor`.
    """
    return super(OneDeviceStrategy, self).reduce(reduce_op, value, axis)

  def scope(self):  # pylint: disable=useless-super-delegation
    """Returns a context manager selecting this Strategy as current.

    Inside a `with strategy.scope():` code block, this thread
    will use a variable creator set by `strategy`, and will
    enter its "cross-replica context".

    In `OneDeviceStrategy`, all variables created inside `strategy.scope()`
    will be placed on the `device` specified at strategy construction time.
    See the example in the docs for this class.
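
    A condensed version of that example (assuming a GPU device):
    ```
    strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
    with strategy.scope():
      v = tf.Variable(1.0)  # Placed on "/device:GPU:0".
    ```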

    Returns:
      A context manager to use for creating variables with this strategy.
    """
    return super(OneDeviceStrategy, self).scope()


@tf_export(v1=["distribute.OneDeviceStrategy"])  # pylint: disable=empty-docstring
class OneDeviceStrategyV1(distribute_lib.StrategyV1):

  __doc__ = OneDeviceStrategy.__doc__.replace(
      "For example:\n  ```",
      "For example:\n  ```\n  tf.enable_eager_execution()")

  def __init__(self, device):
    super(OneDeviceStrategyV1, self).__init__(OneDeviceExtended(self, device))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "OneDeviceStrategy")
  __init__.__doc__ = OneDeviceStrategy.__init__.__doc__


# TODO(josh11b): Switch to V2 after callers have been updated to only V2 APIs.
class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of OneDeviceStrategy."""

  def __init__(self, container_strategy, device):
    super(OneDeviceExtended, self).__init__(container_strategy)
    self._device = device_util.resolve(device)
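    # The input device is the host of the compute device (e.g. the CPU when
    # `device` is a GPU); input pipelines run there and data may be prefetched
    # to `self._device`, depending on the input options.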
    self._input_device = device_util.get_host_for_device(self._device)

  def _input_workers_with_options(self, options=None):
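    # By default (and whenever `experimental_fetch_to_device` is enabled),
    # datasets are read on the host input device and the resulting elements
    # are placed on the compute device; otherwise the data stays on the input
    # (host) device.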
    if not options or options.experimental_fetch_to_device:
      return input_lib.InputWorkers([(self._input_device, (self._device,))])
    else:
      return input_lib.InputWorkers([(self._input_device,
                                      (self._input_device,))])

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _create_variable(self, next_creator, **kwargs):
    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      with ops.device(self._device):
        return next_creator(**kwargs)
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      with ops.device(colocate_with.device):
        return next_creator(**kwargs)
    else:
      with ops.colocate_with(colocate_with):
        return next_creator(**kwargs)

  def _validate_colocate_with_variable(self, colocate_with_variable):
    distribute_utils.validate_colocate(colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    """Make iterator from dataset without splitting the batch."""
    # Note that the `split_batch_by` argument is not passed because it is
    # always 1 in this strategy, and passing it would add unnecessary overhead
    # to the dataset.
    return input_lib_v1.DatasetIterator(dataset, self._input_workers,
                                        self._container_strategy())

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers,
                                              [distribute_lib.InputContext()],
                                              self._container_strategy())

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, numpy_dataset.SingleDevice(self._input_device), session)

  def _broadcast_to(self, tensor, destinations):
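    # With a single device there is nothing to broadcast; return the tensor
    # unchanged.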
    del destinations
    return tensor

  def _experimental_distribute_dataset(self, dataset, options):
    # Note that the `split_batch_by` argument is not passed because it is
    # always 1 in this strategy, and passing it would add unnecessary overhead
    # to the dataset.
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`."
      )
    return input_util.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        options=options)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    return input_util.get_distributed_datasets_from_function(
        dataset_fn,
        self._input_workers_with_options(options),
        [distribute_lib.InputContext()],
        self._container_strategy(),
        options=options)

  def _experimental_distribute_values_from_function(self, value_fn):
    # TODO(b/137795644): This should return a PerReplica value but other
    # methods like run in OneDeviceStrategy need to be modified
    # to do the same.
    return value_fn(distribute_lib.ValueContext())

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = input_lib.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
      fn_result = fn(ctx, iterator.get_next())
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop. This is useful in cases where we might need to exit
    # these contexts and get back to the outer context to do some things,
    # e.g. create an op which should be evaluated only once at the end of the
    # loop on the host. One such usage is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    # TODO(priyag): Use max_iterations instead of an explicit counter.
    cond = lambda i, *args: i < iterations
    i = constant_op.constant(0)
    loop_result = control_flow_ops.while_loop(
        cond, body, [i] + initial_loop_values, name="",
        parallel_iterations=1, back_prop=False, swap_memory=False,
        return_same_structure=True)
    del self._outer_control_flow_context

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx

  def _call_for_each_replica(self, fn, args, kwargs):
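    # With a single device there is only one replica, so `fn` is simply called
    # once, under the device scope and a replica context for this strategy.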
    strategy = self._container_strategy()
    with ops.device(self._device), _OneDeviceReplicaContext(strategy):
      return fn(*args, **kwargs)

  def _reduce_to(self, reduce_op, value, destinations, options):
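    # There is only one replica, so reduction is the identity; return the
    # value unchanged.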
    del reduce_op, destinations, options
    return value

  def _gather_to_implementation(self, value, destinations, axis, options):
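    # Likewise, gathering across a single replica just returns the value.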
    del destinations, axis, options
    return value

  def _update(self, var, fn, args, kwargs, group):
    # The implementations of _update() and _update_non_slot() are identical
    # except _update() passes `var` as the first argument to `fn()`.
    return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    del colocate_with
    with ops.device(self._device), distribute_lib.UpdateContext(self._device):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def read_var(self, replica_local_var):
    """Read the aggregate value of a replica-local variable."""
    return array_ops.identity(replica_local_var)

  def _local_results(self, value):
    return (value,)

  def value_container(self, value):
    return value

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return False

  @property
  def _num_replicas_in_sync(self):
    return 1

  @property
  def worker_devices(self):
    return (self._device,)

  @property
  def parameter_devices(self):
    return (self._device,)

  def non_slot_devices(self, var_list):
    del var_list
    return (self._device,)

  @property
  def experimental_should_init(self):
    return True

  @property
  def experimental_between_graph(self):
    return False

  @property
  def should_checkpoint(self):
    return True

  @property
  def should_save_summary(self):
    return True

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """Global and per-replica batching are equivalent for OneDeviceStrategy."""
    return True

  @property
  def _support_per_replica_values(self):
    return False

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group


class _OneDeviceReplicaContext(distribute_lib.ReplicaContext):
  """ReplicaContext for OneDeviceStrategy."""

  def __init__(self, strategy):
    distribute_lib.ReplicaContext.__init__(
        self, strategy, replica_id_in_sync_group=0)

  @property
  def devices(self):
    return self._strategy.extended.worker_devices