# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Python wrapper for prefetching_ops."""
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.framework import type_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import resource_variable_ops


class _PerDeviceGenerator(dataset_ops.DatasetV2):
  """A `dummy` generator dataset.

  Fetches elements for one shard of a multi-device iterator via remote calls
  to the source device.
  """

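  # The generated dataset is backed by the GeneratorDataset op, which is
  # parameterized by three functions:
  #   * init_func: returns a string handle for the parent multi-device
  #     iterator (obtained on the source device via remote_call).
  #   * next_func: fetches the next element for this shard from the
  #     multi-device iterator (again via remote_call to the source device).
  #   * finalize_func: a no-op that exists to satisfy the GeneratorDataset
  #     contract.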
  def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
               source_device, element_spec, iterator_is_anonymous):
    self._element_spec = element_spec

    multi_device_iterator_string_handle = (
        gen_dataset_ops.multi_device_iterator_to_string_handle(
            multi_device_iterator_resource))

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(autograph=False)  # Pure graph code.
    def _init_func():
      return multi_device_iterator_string_handle

    init_func_concrete = _init_func.get_concrete_function()

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(autograph=False)  # Pure graph code.
    def _remote_init_func():
      return functional_ops.remote_call(
          target=source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func.get_concrete_function()
    self._init_captured_args = self._init_func.captured_inputs

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _next_func(string_handle):
      # pylint: disable=protected-access
      multi_device_iterator = (
          gen_dataset_ops.multi_device_iterator_from_string_handle(
              string_handle=string_handle,
              output_types=structure.get_flat_tensor_types(self._element_spec),
              output_shapes=structure.get_flat_tensor_shapes(
                  self._element_spec)))
      return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
          multi_device_iterator=multi_device_iterator,
          shard_num=shard_num,
          incarnation_id=incarnation_id,
          output_types=structure.get_flat_tensor_types(self._element_spec),
          output_shapes=structure.get_flat_tensor_shapes(self._element_spec))

    next_func_concrete = _next_func.get_concrete_function()

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True},
        autograph=False)  # Pure graph code.
    def _remote_next_func(string_handle):
      return_values = functional_ops.remote_call(
          target=source_device,
          args=[string_handle] + next_func_concrete.captured_inputs,
          Tout=structure.get_flat_tensor_types(self._element_spec),
          f=next_func_concrete)
      # Add full type information to the graph so that the RemoteCall op
      # can determine for each of its outputs whether or not they are ragged
      # tensors (or other types that use variants) that contain strings
      # (or other host memory types). Then RemoteCall can
      # appropriately set AllocatorAttributes to control copies so
      # strings/host memory types stay on CPU.
      fulltype_list = type_utils.fulltypes_for_flat_tensors(self._element_spec)
      fulltype = type_utils.fulltype_list_to_product(fulltype_list)
      for return_value in return_values:
        return_value.op.experimental_set_type(fulltype)
      return return_values

    self._next_func = _remote_next_func.get_concrete_function()
    self._next_captured_args = self._next_func.captured_inputs

    if iterator_is_anonymous:
      self._next_captured_args = self._next_captured_args + [
          multi_device_iterator_resource
      ]

    self._incarnation_id_index = -1
    for i, arg in enumerate(self._next_captured_args):
      if arg is incarnation_id:
        self._incarnation_id_index = i

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _finalize_func(unused_string_handle):
      return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func.get_concrete_function()

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _remote_finalize_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func.get_concrete_function()
    self._finalize_captured_args = self._finalize_func.captured_inputs

    variant_tensor = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **self._flat_structure)
    super(_PerDeviceGenerator, self).__init__(variant_tensor)

  def _inputs(self):
    # TODO(b/116506223): Determine which datasets should be used as inputs here.
    return []

  @property
  def element_spec(self):
    return self._element_spec


class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2):
  """Creates a _PerDeviceGenerator-like dataset with a new incarnation_id.

  Re-uses the functions from the provided per_device_dataset and just switches
  out the function argument corresponding to the incarnation_id.
  """

  def __init__(self, per_device_dataset, incarnation_id):
    # pylint: disable=protected-access
    self._element_spec = per_device_dataset.element_spec
    self._init_func = per_device_dataset._init_func
    self._init_captured_args = self._init_func.captured_inputs

    self._next_func = per_device_dataset._next_func
    self._next_captured_args = per_device_dataset._next_captured_args
    # The captured arguments of next_func include the incarnation_id; update
    # it to the new incarnation_id.
    self._next_captured_args[
        per_device_dataset._incarnation_id_index] = incarnation_id

    self._finalize_func = per_device_dataset._finalize_func
    self._finalize_captured_args = per_device_dataset._finalize_captured_args

    variant_tensor = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **self._flat_structure)
    super(_ReincarnatedPerDeviceGenerator, self).__init__(variant_tensor)

  def _inputs(self):
    # TODO(b/116506223): Determine which datasets should be used as inputs here.
    return []

  @property
  def element_spec(self):
    return self._element_spec


def _create_device_dataset(prototype_ds, incarnation_id, prefetch_buffer_size,
                           experimental_slack):
  """Uses _prototype_device_datasets[i] to build a dataset for the device."""
  ds = _ReincarnatedPerDeviceGenerator(prototype_ds, incarnation_id)
  if prefetch_buffer_size > 0:
    if experimental_slack:
      ds = dataset_ops.PrefetchDataset(ds, prefetch_buffer_size, slack_period=1)
    else:
      ds = ds.prefetch(prefetch_buffer_size)
  return ds
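
# Note: the device-side pipelines are rebuilt from their prototypes whenever
# the multi-device iterator is (re)initialized: multi_device_iterator_init
# mints a fresh incarnation_id, which is threaded through
# _create_device_dataset. A rough sketch with illustrative values:
#
#   ds = _create_device_dataset(prototype_ds, new_incarnation_id,
#                               prefetch_buffer_size=1,
#                               experimental_slack=False)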


class MultiDeviceIterator:
  """An iterator over multiple devices."""

  def __init__(self,
               dataset,
               devices,
               max_buffer_size=1,
               prefetch_buffer_size=1,
               source_device="/cpu:0"):
    """Constructs a MultiDeviceIterator.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to keep.
      prefetch_buffer_size: if > 0, then we set up a buffer on each device to
        prefetch into.
      source_device: The host device to place the `dataset` on.  In order to
        prevent deadlocks, if the prefetch_buffer_size is greater than the
        max_buffer_size, we set the max_buffer_size to prefetch_buffer_size.
    """
    options = options_lib.Options()
    options.experimental_distribute.num_devices = len(devices)
    dataset = dataset.with_options(options)
    self._dataset = dataset._apply_debug_options()  # pylint: disable=protected-access
    self._experimental_slack = dataset.options().experimental_slack
    self._devices = devices
    self._source_device = source_device
    self._source_device_tensor = ops.convert_to_tensor(source_device)
    self._max_buffer_size = max_buffer_size
    self._prefetch_buffer_size = prefetch_buffer_size

    if self._prefetch_buffer_size > self._max_buffer_size:
      self._max_buffer_size = self._prefetch_buffer_size

    # Create the MultiDeviceIterator.
    with ops.device(self._source_device):
      # TODO(b/121378567): Get rid of this shared_name hack.
      shared_name = ""
      if context.executing_eagerly():
        shared_name = context.anonymous_name()
      self._multi_device_iterator_resource = (
          gen_dataset_ops.multi_device_iterator(
              devices=self._devices,
              shared_name=shared_name,
              container="",
              **self._dataset._flat_structure))  # pylint: disable=protected-access
      if context.executing_eagerly():
        # Delete the resource when this object is deleted.
        self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
            handle=self._multi_device_iterator_resource,
            handle_device=self._source_device)

      # The incarnation ID is used to ensure consistency between the per-device
      # iterators and the multi-device iterator.
      self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
          self._dataset._variant_tensor,  # pylint: disable=protected-access
          self._multi_device_iterator_resource,
          max_buffer_size=self._max_buffer_size)

    self._prototype_device_datasets = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _PerDeviceGenerator(
            i,
            self._multi_device_iterator_resource,
            self._incarnation_id,
            self._source_device_tensor,
            self._dataset.element_spec,
            iterator_is_anonymous=False)
        self._prototype_device_datasets.append(ds)

    # TODO(rohanj): Explore the possibility of the MultiDeviceIterator
    # initializing the device side of the pipeline. This would allow the
    # MultiDeviceIterator to choose, for example, to move some transformations
    # into the device side from its input. It might be useful in rewriting.
    # Create the per device iterators.
    self._device_iterators = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _create_device_dataset(self._prototype_device_datasets[i],
                                    self._incarnation_id,
                                    self._prefetch_buffer_size,
                                    self._experimental_slack)
        if context.executing_eagerly():
          self._device_iterators.append(dataset_ops.make_one_shot_iterator(ds))
        else:
          self._device_iterators.append(
              dataset_ops.make_initializable_iterator(ds))

    if not context.executing_eagerly():
      device_iterator_initializers = [
          iterator.initializer for iterator in self._device_iterators
      ]
      self._initializer = control_flow_ops.group(*device_iterator_initializers)

  def get_next(self, device=None):
    """Returns the next element for `device`.

    If `device` is None, returns a list with the next element for each device.
    """
    if device is not None:
      index = self._devices.index(device)
      return self._device_iterators[index].get_next()

    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next())
    return result

  def get_next_as_optional(self):
    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next_as_optional())
    return result

  @property
  def initializer(self):
    if context.executing_eagerly():
      return control_flow_ops.no_op()
    return self._initializer

  def _eager_reset(self):
    """Resets the MultiDeviceIterator in eager mode."""
    if not ops.executing_eagerly_outside_functions():
      raise ValueError(
          "Resetting a multi-device iterator is only supported in eager "
          "mode.")
    # pylint: disable=protected-access
    self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
        self._dataset._variant_tensor,
        self._multi_device_iterator_resource,
        max_buffer_size=self._max_buffer_size)
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _create_device_dataset(self._prototype_device_datasets[i],
                                    self._incarnation_id,
                                    self._prefetch_buffer_size,
                                    self._experimental_slack)
        # Reset the device iterator resources with the new dataset.
        ds_variant = ds._variant_tensor
        gen_dataset_ops.make_iterator(
            ds_variant, self._device_iterators[i]._iterator_resource)

  @property
  def element_spec(self):
    return self._dataset.element_spec


class MultiDeviceIteratorSpec(type_spec.TypeSpec):
  """Type specification for `OwnedMultiDeviceIterator`."""

  __slots__ = ["_devices", "_source_device", "_element_spec"]

  def __init__(self, devices, source_device, element_spec):
    self._devices = devices
    self._source_device = source_device
    self._element_spec = element_spec

  @property
  def value_type(self):
    return OwnedMultiDeviceIterator

  def _serialize(self):
    return (tuple(self._devices), self._source_device, self._element_spec)

  @property
  def _component_specs(self):
    specs = [
        tensor_spec.TensorSpec([], dtypes.resource),
    ]
    for _ in range(len(self._devices)):
      specs.append(iterator_ops.IteratorSpec(self._element_spec))
    return specs

  def _to_components(self, value):
    # pylint: disable=protected-access
    c = [value._multi_device_iterator_resource]
    c.extend(value._device_iterators)
    return c

  def _from_components(self, components):
    return OwnedMultiDeviceIterator(
        dataset=None,
        devices=self._devices,
        source_device=self._source_device,
        components=components,
        element_spec=self._element_spec)

  @staticmethod
  def from_value(value):
    # pylint: disable=protected-access
    return MultiDeviceIteratorSpec(
        value._devices,
        value._source_device,
        value.element_spec)


class OwnedMultiDeviceIterator(composite_tensor.CompositeTensor):
  """An iterator over multiple devices.

  The multi-device iterator resource created through `OwnedMultiDeviceIterator`
  is owned by the Python object and the lifetime of the underlying resource is
  tied to the lifetime of the `OwnedMultiDeviceIterator` object. This makes
  `OwnedMultiDeviceIterator` appropriate for use in eager mode and inside of
  tf.functions.
  """

  def __init__(self,
               dataset=None,
               devices=None,
               max_buffer_size=1,
               prefetch_buffer_size=1,
               source_device="/cpu:0",
               components=None,
               element_spec=None):
    """Constructs an owned MultiDeviceIterator object.

    Args:
      dataset: The input dataset to be iterated over.
      devices: (Required.) The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to keep.
      prefetch_buffer_size: if > 0, then we set up a buffer on each device to
        prefetch into.
      source_device: The host device to place the `dataset` on.  In order to
        prevent deadlocks, if the prefetch_buffer_size is greater than the
        max_buffer_size, we set the max_buffer_size to prefetch_buffer_size.
      components: Tensor components to construct the MultiDeviceIterator from.
      element_spec: A (nested) structure of `tf.TypeSpec` objects that
        represents the type specification of elements of the iterator.

    Raises:
      RuntimeError: If executed in graph mode or outside of function building
        mode.
      ValueError: If any of the following happens:
        - `devices` is `None`
        - `dataset` is `None` and either `components` or `element_spec` is
          `None`
        - `dataset` is not `None` and either `components` or `element_spec` is
          provided
    """
    if not context.executing_eagerly() and not ops.inside_function():
      raise RuntimeError("OwnedMultiDeviceIterator is only supported inside of "
                         "tf.function or when eager execution is enabled.")
    if devices is None:
      raise ValueError("`devices` must be provided.")

    if dataset is None:
      if (components is None or element_spec is None):
        raise ValueError(
            "When `dataset` is not provided, both `components` and "
            "`element_spec` must be specified.")
      self._element_spec = element_spec
      self._devices = devices
      self._source_device = source_device
      self._multi_device_iterator_resource = components[0]
      self._device_iterators = components[1:]
    else:
      if (components is not None or element_spec is not None):
        raise ValueError(
            "When `dataset` is provided, `element_spec` and `components` must "
            "not be specified.")
      options = options_lib.Options()
      options.experimental_distribute.num_devices = len(devices)
      dataset = dataset.with_options(options)
      dataset = dataset._apply_debug_options()  # pylint: disable=protected-access
      self._element_spec = dataset.element_spec
      experimental_slack = dataset.options().experimental_slack
      self._devices = devices
      self._source_device = source_device
      source_device_tensor = ops.convert_to_tensor(self._source_device)

      if prefetch_buffer_size > max_buffer_size:
        max_buffer_size = prefetch_buffer_size

      # Create the MultiDeviceIterator.
      with ops.device(self._source_device):
        self._multi_device_iterator_resource = (
            gen_dataset_ops.anonymous_multi_device_iterator_v3(
                devices=self._devices, **dataset._flat_structure))  # pylint: disable=protected-access

        # The incarnation ID is used to ensure consistency between the
        # per-device iterators and the multi-device iterator.
        incarnation_id = gen_dataset_ops.multi_device_iterator_init(
            dataset._variant_tensor,  # pylint: disable=protected-access
            self._multi_device_iterator_resource,
            max_buffer_size=max_buffer_size)

      prototype_device_datasets = []
      for i, device in enumerate(self._devices):
        with ops.device(device):
          ds = _PerDeviceGenerator(
              i,
              self._multi_device_iterator_resource,
              incarnation_id,
              source_device_tensor,
              dataset.element_spec,
              iterator_is_anonymous=True,
          )
          prototype_device_datasets.append(ds)

      # TODO(rohanj): Explore the possibility of the MultiDeviceIterator
      # initializing the device side of the pipeline. This would allow the
      # MultiDeviceIterator to choose, for example, to move some
      # transformations into the device side from its input. It might be
      # useful in rewriting.
      # Create the per device iterators.
      self._device_iterators = []

      for i, device in enumerate(self._devices):
        with ops.device(device):
          ds = _create_device_dataset(prototype_device_datasets[i],
                                      incarnation_id, prefetch_buffer_size,
                                      experimental_slack)
          iterator = iter(ds)
          self._device_iterators.append(iterator)

  def get_next(self, device=None):
    """Returns the next element for `device`.

    If `device` is None, returns a list with the next element for each device.
    """
    if device is not None:
      index = self._devices.index(device)
      return self._device_iterators[index].get_next()

    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next())
    return result

  def __iter__(self):
    return self

  def next(self):
    return self.__next__()

  def __next__(self):
    try:
      return self.get_next()
    except errors.OutOfRangeError:
      raise StopIteration

  def get_next_as_optional(self):
    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next_as_optional())
    return result

  @property
  def element_spec(self):
    return self._element_spec

  @property
  def _type_spec(self):
    return MultiDeviceIteratorSpec(self._devices, self._source_device,
                                   self._element_spec)
573