xref: /aosp_15_r20/external/tensorflow/tensorflow/python/data/experimental/ops/prefetching_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrapper for prefetching_ops."""
16from tensorflow.python.data.ops import dataset_ops
17from tensorflow.python.data.ops import iterator_ops
18from tensorflow.python.data.ops import structured_function
19from tensorflow.python.data.util import structure
20from tensorflow.python.eager import function
21from tensorflow.python.framework import device as framework_device
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import tensor_spec
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import functional_ops
27from tensorflow.python.ops import gen_dataset_ops
28from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
29from tensorflow.python.ops import resource_variable_ops
30from tensorflow.python.util.tf_export import tf_export
31
32
@tf_export("data.experimental.prefetch_to_device")
def prefetch_to_device(device, buffer_size=None):
  """Builds a transformation that prefetches dataset elements onto `device`.

  The returned function first copies the elements of the dataset it is applied
  to onto `device` (via `copy_to_device`) and then prefetches the copied
  elements with the requested `buffer_size`.

  NOTE: Although the transformation creates a `tf.data.Dataset`, the
  transformation must be the final `Dataset` in the input pipeline.

  For example,
  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  >>> dataset = dataset.apply(tf.data.experimental.prefetch_to_device("/cpu:0"))
  >>> for element in dataset:
  ...   print(f'Tensor {element} is on device {element.device}')
  Tensor 1 is on device /job:localhost/replica:0/task:0/device:CPU:0
  Tensor 2 is on device /job:localhost/replica:0/task:0/device:CPU:0
  Tensor 3 is on device /job:localhost/replica:0/task:0/device:CPU:0

  Args:
    device: A string. The name of a device to which elements will be prefetched.
    buffer_size: (Optional.) The number of elements to buffer on `device`.
      Defaults to an automatically chosen value.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _copy_then_prefetch(dataset):
    # Step 1: move the elements onto the requested device.
    on_device = dataset.apply(copy_to_device(target_device=device))
    # Step 2: buffer the already-copied elements.
    return on_device.prefetch(buffer_size)

  return _copy_then_prefetch
63
64
@tf_export("data.experimental.copy_to_device")
def copy_to_device(target_device, source_device="/cpu:0"):
  """Builds a transformation that copies dataset elements to `target_device`.

  Args:
    target_device: The name of a device to which elements will be copied.
    source_device: The original device on which `input_dataset` will be placed.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _wrap_in_copy_dataset(dataset):
    # Positional order matches `_CopyToDeviceDataset.__init__`:
    # (input_dataset, target_device, source_device).
    return _CopyToDeviceDataset(dataset, target_device, source_device)

  return _wrap_in_copy_dataset
83
84
85# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
86# all inputs to the Op are in host memory, thereby avoiding some unnecessary
87# Sends and Recvs.
class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that copies elements to another device.

  The dataset is built as a `generator_dataset` op created under
  `target_device`. Its init / next / finalize functions are thin wrappers that
  each issue a `remote_call` targeting `source_device`, where the wrapped
  function creates, advances, or destroys an iterator over the input dataset.
  """

  def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
    """Constructs a _CopyToDeviceDataset.

    Args:
      input_dataset: `Dataset` to be copied
      target_device: The name of the device to which elements would be copied.
      source_device: Device where input_dataset would be placed.
    """
    self._input_dataset = input_dataset._apply_debug_options()  # pylint: disable=protected-access
    self._target_device = target_device
    # `from_string` is a classmethod, so the device name can be parsed without
    # first constructing an empty `DeviceSpec` instance.
    spec = framework_device.DeviceSpec.from_string(self._target_device)
    self._is_gpu_target = (spec.device_type == "GPU")
    # The source device is kept both as a plain string (for `ops.device`) and
    # as a tensor (for the `target` argument of `remote_call`).
    self._source_device_string = source_device
    self._source_device = ops.convert_to_tensor(source_device)

    # Wrapped here so that `_init_func` below captures the wrapped variant and
    # unwraps it inside its own function body.
    wrap_ds_variant = gen_dataset_ops.wrap_dataset_variant(
        self._input_dataset._variant_tensor)  # pylint: disable=protected-access

    @function.defun()
    def _init_func():
      """Creates an iterator for the input dataset.

      Returns:
        A `string` tensor that encapsulates the iterator created.
      """
      ds_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
      resource = gen_dataset_ops.anonymous_iterator(
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
      # The string handle is only produced once the iterator is initialized.
      with ops.control_dependencies(
          [gen_dataset_ops.make_iterator(ds_variant, resource)]):
        return gen_dataset_ops.iterator_to_string_handle(resource)

    init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

    @function.defun()
    def _remote_init_func():
      """Runs `_init_func` through a `remote_call` targeting the source device."""
      return functional_ops.remote_call(
          target=self._source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._init_captured_args = self._init_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _next_func(string_handle):
      """Calls get_next for created iterator.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        The elements generated from `input_dataset`
      """
      with ops.device(self._source_device_string):
        iterator = iterator_ops.Iterator.from_string_handle(
            string_handle,
            dataset_ops.get_legacy_output_types(self),
            dataset_ops.get_legacy_output_shapes(self),
            dataset_ops.get_legacy_output_classes(self))
      return structure.to_tensor_list(self.element_spec, iterator.get_next())

    next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True})
    def _remote_next_func(string_handle):
      """Runs `_next_func` through a `remote_call` targeting the source device."""
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + next_func_concrete.captured_inputs,
          Tout=self._input_dataset._flat_types,  # pylint: disable=protected-access
          f=next_func_concrete)

    self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._next_captured_args = self._next_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _finalize_func(string_handle):
      """Destroys the iterator resource created.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        Tensor constant 0
      """
      iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
          string_handle,
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
      # `ignore_lookup_error=True` is passed so a failed resource lookup is
      # not reported as an error by the destroy op.
      with ops.control_dependencies([
          resource_variable_ops.destroy_resource_op(
              iterator_resource, ignore_lookup_error=True)]):
        return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_finalize_func(string_handle):
      """Runs `_finalize_func` through a `remote_call` targeting the source device."""
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = (
        _remote_finalize_func._get_concrete_function_internal())  # pylint: disable=protected-access
    self._finalize_captured_args = self._finalize_func.captured_inputs

    # Register the three wrapper functions with the current default graph
    # before creating the op that refers to them.
    g = ops.get_default_graph()
    self._init_func.add_to_graph(g)
    self._next_func.add_to_graph(g)
    self._finalize_func.add_to_graph(g)

    with ops.device(self._target_device):
      variant_tensor = gen_dataset_ops.generator_dataset(
          self._init_captured_args,
          self._next_captured_args,
          self._finalize_captured_args,
          init_func=self._init_func,
          next_func=self._next_func,
          finalize_func=self._finalize_func,
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
    # NOTE(review): the base class receives the original `input_dataset`, not
    # the debug-options-applied `self._input_dataset` used to build the ops
    # above -- confirm this is intended.
    super().__init__(input_dataset, variant_tensor)

  # The one_shot_iterator implementation needs a 0 arg _make_dataset function
  # that thereby captures all the inputs required to create the dataset. Since
  # there are strings that are inputs to the GeneratorDataset which can't be
  # placed on a GPU, this fails for the GPU case. Therefore, disabling it for
  # GPU
  def make_one_shot_iterator(self):
    """Creates a one-shot iterator unless the target device is a GPU.

    Returns:
      Whatever the parent class's `make_one_shot_iterator` returns.

    Raises:
      ValueError: If the target device type is "GPU".
    """
    if self._is_gpu_target:
      raise ValueError(
          "`make_one_shot_iterator` is not compatible with GPU execution. "
          "Please use `Dataset.make_initializable_iterator()` instead."
      )
    return super().make_one_shot_iterator()
229
230
class _MapOnGpuDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over elements of its input using a GPU."""

  def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
    """See `Dataset.map()` for details."""
    self._input_dataset = input_dataset
    self._use_inter_op_parallelism = use_inter_op_parallelism

    # The wrapper must be stored on `self` before `_flat_structure` is read
    # below, because `element_spec` is derived from it.
    wrapped = structured_function.StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        defun_kwargs={"experimental_ints_on_device": True})
    self._map_func = wrapped

    traced_fn = wrapped.function
    map_variant = ged_ops.experimental_map_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        traced_fn.captured_inputs,
        f=traced_fn,
        use_inter_op_parallelism=use_inter_op_parallelism,
        **self._flat_structure)
    super().__init__(input_dataset, map_variant)

  def _functions(self):
    """Returns the wrapped map function as a single-element list."""
    return [self._map_func]

  @property
  def element_spec(self):
    """The output structure reported by the wrapped map function."""
    return self._map_func.output_structure

  def _transformation_name(self):
    """Name handed to the function wrapper for this transformation."""
    return "map_on_gpu()"
261
262
def map_on_gpu(map_func):
  """Builds a transformation that maps `map_func` over a dataset's elements.

  NOTE: This is a highly experimental version of `tf.data.Dataset.map` that runs
  `map_func` on GPU. It must be used after applying the
  `tf.data.experimental.copy_to_device` transformation with a GPU device
  argument.

  Args:
    map_func: A function mapping a nested structure of tensors (having the
      shapes and types of the elements of the dataset the transformation is
      applied to) to another nested structure of tensors.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _gpu_map_transformation(dataset):
    # Deferred construction: the dataset is only known once the returned
    # function is handed to `Dataset.apply`.
    return _MapOnGpuDataset(dataset, map_func)

  return _gpu_map_transformation
285