# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrapper for prefetching_ops."""
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import structured_function
from tensorflow.python.data.util import structure
from tensorflow.python.eager import function
from tensorflow.python.framework import device as framework_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.prefetch_to_device")
def prefetch_to_device(device, buffer_size=None):
  """A transformation that prefetches dataset values to the given `device`.

  NOTE: Although the transformation creates a `tf.data.Dataset`, the
  transformation must be the final `Dataset` in the input pipeline.

  For example,
  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  >>> dataset = dataset.apply(tf.data.experimental.prefetch_to_device("/cpu:0"))
  >>> for element in dataset:
  ...   print(f'Tensor {element} is on device {element.device}')
  Tensor 1 is on device /job:localhost/replica:0/task:0/device:CPU:0
  Tensor 2 is on device /job:localhost/replica:0/task:0/device:CPU:0
  Tensor 3 is on device /job:localhost/replica:0/task:0/device:CPU:0

  Args:
    device: A string. The name of a device to which elements will be prefetched.
    buffer_size: (Optional.) The number of elements to buffer on `device`.
      Defaults to an automatically chosen value.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    # Copy elements to `device` first, then prefetch so the buffer lives on
    # the target device rather than on the host.
    copied = dataset.apply(copy_to_device(target_device=device))
    return copied.prefetch(buffer_size)

  return _apply_fn


@tf_export("data.experimental.copy_to_device")
def copy_to_device(target_device, source_device="/cpu:0"):
  """A transformation that copies dataset elements to the given `target_device`.

  Args:
    target_device: The name of a device to which elements will be copied.
    source_device: The original device on which `input_dataset` will be placed.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _copy_transformation(dataset):
    return _CopyToDeviceDataset(
        dataset,
        target_device=target_device,
        source_device=source_device)

  return _copy_transformation


# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
# all inputs to the Op are in host memory, thereby avoiding some unnecessary
# Sends and Recvs.
class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that copies elements to another device."""

  def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
    """Constructs a _CopyToDeviceDataset.

    Args:
      input_dataset: `Dataset` to be copied.
      target_device: The name of the device to which elements would be copied.
      source_device: Device where input_dataset would be placed.
    """
    self._input_dataset = input_dataset._apply_debug_options()  # pylint: disable=protected-access
    self._target_device = target_device
    # `DeviceSpec.from_string` is a classmethod; call it on the class instead
    # of constructing a throwaway instance first.
    spec = framework_device.DeviceSpec.from_string(self._target_device)
    self._is_gpu_target = (spec.device_type == "GPU")
    self._source_device_string = source_device
    self._source_device = ops.convert_to_tensor(source_device)

    # Wrap the variant so it can be smuggled through the traced functions
    # below and unwrapped on the source device.
    wrap_ds_variant = gen_dataset_ops.wrap_dataset_variant(
        self._input_dataset._variant_tensor)  # pylint: disable=protected-access

    @function.defun()
    def _init_func():
      """Creates an iterator for the input dataset.

      Returns:
        A `string` tensor that encapsulates the iterator created.
      """
      ds_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
      resource = gen_dataset_ops.anonymous_iterator(
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
      with ops.control_dependencies(
          [gen_dataset_ops.make_iterator(ds_variant, resource)]):
        return gen_dataset_ops.iterator_to_string_handle(resource)

    init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

    @function.defun()
    def _remote_init_func():
      # Run `_init_func` on the source device so the iterator is created
      # next to the input dataset's data.
      return functional_ops.remote_call(
          target=self._source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._init_captured_args = self._init_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _next_func(string_handle):
      """Calls get_next for created iterator.

      Args:
        string_handle: An iterator string handle created by _init_func

      Returns:
        The elements generated from `input_dataset`.
      """
      with ops.device(self._source_device_string):
        iterator = iterator_ops.Iterator.from_string_handle(
            string_handle,
            dataset_ops.get_legacy_output_types(self),
            dataset_ops.get_legacy_output_shapes(self),
            dataset_ops.get_legacy_output_classes(self))
      return structure.to_tensor_list(self.element_spec, iterator.get_next())

    next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # `experimental_ints_on_device` keeps the remote call's outputs from
    # being pinned to host memory when the target is a GPU.
    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True})
    def _remote_next_func(string_handle):
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + next_func_concrete.captured_inputs,
          Tout=self._input_dataset._flat_types,  # pylint: disable=protected-access
          f=next_func_concrete)

    self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._next_captured_args = self._next_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _finalize_func(string_handle):
      """Destroys the iterator resource created.

      Args:
        string_handle: An iterator string handle created by _init_func

      Returns:
        Tensor constant 0.
      """
      iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
          string_handle,
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
      with ops.control_dependencies([
          resource_variable_ops.destroy_resource_op(
              iterator_resource, ignore_lookup_error=True)]):
        return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_finalize_func(string_handle):
      # Destroy the iterator on the device where it was created.
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
    )
    self._finalize_captured_args = self._finalize_func.captured_inputs

    g = ops.get_default_graph()
    self._init_func.add_to_graph(g)
    self._next_func.add_to_graph(g)
    self._finalize_func.add_to_graph(g)

    # Build the GeneratorDataset op on the target device so that its outputs
    # are produced there.
    with ops.device(self._target_device):
      variant_tensor = gen_dataset_ops.generator_dataset(
          self._init_captured_args,
          self._next_captured_args,
          self._finalize_captured_args,
          init_func=self._init_func,
          next_func=self._next_func,
          finalize_func=self._finalize_func,
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
    super(_CopyToDeviceDataset, self).__init__(input_dataset, variant_tensor)
  # The one_shot_iterator implementation needs a 0 arg _make_dataset function
  # that thereby captures all the inputs required to create the dataset. Since
  # there are strings that are inputs to the GeneratorDataset which can't be
  # placed on a GPU, this fails for the GPU case. Therefore, disabling it for
  # GPU.
  def make_one_shot_iterator(self):
    """Raises for GPU targets; otherwise defers to the base implementation."""
    if self._is_gpu_target:
      raise ValueError(
          "`make_one_shot_iterator` is not compatible with GPU execution. "
          "Please use `Dataset.make_initializable_iterator()` instead."
      )
    else:
      return super(_CopyToDeviceDataset, self).make_one_shot_iterator()


class _MapOnGpuDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over its elements using a GPU."""

  def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
    """See `Dataset.map()` for details.

    Args:
      input_dataset: The `Dataset` whose elements are transformed.
      map_func: A function applied to each element of `input_dataset`.
      use_inter_op_parallelism: Whether the map op is allowed to use
        inter-op parallelism.
    """
    self._input_dataset = input_dataset
    self._use_inter_op_parallelism = use_inter_op_parallelism

    # Trace `map_func` once; the `experimental_ints_on_device` attribute
    # presumably keeps integer tensors on the device rather than pinning
    # them to host memory — matches the attr used by _CopyToDeviceDataset.
    self._map_func = structured_function.StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        defun_kwargs={"experimental_ints_on_device": True})
    variant_tensor = ged_ops.experimental_map_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        use_inter_op_parallelism=self._use_inter_op_parallelism,
        **self._flat_structure)
    super(_MapOnGpuDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    # Expose the wrapped map function for dataset introspection.
    return [self._map_func]

  @property
  def element_spec(self):
    # The output structure is defined by what the traced map function returns.
    return self._map_func.output_structure

  def _transformation_name(self):
    return "map_on_gpu()"


def map_on_gpu(map_func):
  """Maps `map_func` across the elements of this dataset.

  NOTE: This is a highly experimental version of `tf.data.Dataset.map` that runs
  `map_func` on GPU. It must be used after applying the
  `tf.data.experimental.copy_to_device` transformation with a GPU device
  argument.

  Args:
    map_func: A function mapping a nested structure of tensors (having shapes
      and types defined by `self.output_shapes` and `self.output_types`) to
      another nested structure of tensors.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return _MapOnGpuDataset(dataset, map_func)

  return _apply_fn