# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrapper for prefetching_ops."""
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.framework import type_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import resource_variable_ops


class _PerDeviceGenerator(dataset_ops.DatasetV2):
  """A `dummy` generator dataset.

  Wraps remote calls that fetch elements from a single shard of a
  multi-device iterator on the source device.
  """

  def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
               source_device, element_spec, iterator_is_anonymous):
    self._element_spec = element_spec

    multi_device_iterator_string_handle = (
        gen_dataset_ops.multi_device_iterator_to_string_handle(
            multi_device_iterator_resource))

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(autograph=False)  # Pure graph code.
    def _init_func():
      return multi_device_iterator_string_handle

    init_func_concrete = _init_func.get_concrete_function()

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(autograph=False)  # Pure graph code.
    def _remote_init_func():
      return functional_ops.remote_call(
          target=source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func.get_concrete_function()
    self._init_captured_args = self._init_func.captured_inputs

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _next_func(string_handle):
      # pylint: disable=protected-access
      multi_device_iterator = (
          gen_dataset_ops.multi_device_iterator_from_string_handle(
              string_handle=string_handle,
              output_types=structure.get_flat_tensor_types(self._element_spec),
              output_shapes=structure.get_flat_tensor_shapes(
                  self._element_spec)))
      return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
          multi_device_iterator=multi_device_iterator,
          shard_num=shard_num,
          incarnation_id=incarnation_id,
          output_types=structure.get_flat_tensor_types(self._element_spec),
          output_shapes=structure.get_flat_tensor_shapes(self._element_spec))

    next_func_concrete = _next_func.get_concrete_function()

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True},
        autograph=False)  # Pure graph code.
    def _remote_next_func(string_handle):
      return_values = functional_ops.remote_call(
          target=source_device,
          args=[string_handle] + next_func_concrete.captured_inputs,
          Tout=structure.get_flat_tensor_types(self._element_spec),
          f=next_func_concrete)
      # Add full type information to the graph so that the RemoteCall op can
      # determine, for each of its outputs, whether it is a ragged tensor (or
      # another type that uses variants) that contains strings or other host
      # memory types. RemoteCall can then set AllocatorAttributes to control
      # copies, so strings and other host memory types stay on the CPU.
      fulltype_list = type_utils.fulltypes_for_flat_tensors(self._element_spec)
      fulltype = type_utils.fulltype_list_to_product(fulltype_list)
      for return_value in return_values:
        return_value.op.experimental_set_type(fulltype)
      return return_values

    self._next_func = _remote_next_func.get_concrete_function()
    self._next_captured_args = self._next_func.captured_inputs

    if iterator_is_anonymous:
      self._next_captured_args = self._next_captured_args + [
          multi_device_iterator_resource
      ]

    self._incarnation_id_index = -1
    for i, arg in enumerate(self._next_captured_args):
      if arg is incarnation_id:
        self._incarnation_id_index = i

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _finalize_func(unused_string_handle):
      return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func.get_concrete_function()

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _remote_finalize_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func.get_concrete_function()
    self._finalize_captured_args = self._finalize_func.captured_inputs

    variant_tensor = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **self._flat_structure)
    super(_PerDeviceGenerator, self).__init__(variant_tensor)

  def _inputs(self):
    # TODO(b/116506223): Determine which datasets should be used as inputs
    # here.
    return []

  @property
  def element_spec(self):
    return self._element_spec


class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2):
  """Creates a `_PerDeviceGenerator`-like dataset with a new incarnation_id.

  Reuses the functions from the provided `per_device_dataset` and swaps out
  only the function argument corresponding to the incarnation_id.
  """

  def __init__(self, per_device_dataset, incarnation_id):
    # pylint: disable=protected-access
    self._element_spec = per_device_dataset.element_spec
    self._init_func = per_device_dataset._init_func
    self._init_captured_args = self._init_func.captured_inputs

    self._next_func = per_device_dataset._next_func
    self._next_captured_args = per_device_dataset._next_captured_args
    # The captured arguments to the next_func are string_handle and
    # incarnation_id; update the incarnation id to the new one.
    self._next_captured_args[
        per_device_dataset._incarnation_id_index] = incarnation_id

    self._finalize_func = per_device_dataset._finalize_func
    self._finalize_captured_args = per_device_dataset._finalize_captured_args

    variant_tensor = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **self._flat_structure)
    super(_ReincarnatedPerDeviceGenerator, self).__init__(variant_tensor)

  def _inputs(self):
    # TODO(b/116506223): Determine which datasets should be used as inputs
    # here.
    return []

  @property
  def element_spec(self):
    return self._element_spec


def _create_device_dataset(prototype_ds, incarnation_id, prefetch_buffer_size,
                           experimental_slack):
  """Builds a per-device dataset from a prototype dataset and incarnation id."""
  ds = _ReincarnatedPerDeviceGenerator(prototype_ds, incarnation_id)
  if prefetch_buffer_size > 0:
    if experimental_slack:
      ds = dataset_ops.PrefetchDataset(ds, prefetch_buffer_size, slack_period=1)
    else:
      ds = ds.prefetch(prefetch_buffer_size)
  return ds


class MultiDeviceIterator:
  """An iterator over multiple devices."""

  def __init__(self,
               dataset,
               devices,
               max_buffer_size=1,
               prefetch_buffer_size=1,
               source_device="/cpu:0"):
    """Constructs a MultiDeviceIterator.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host-side per-device buffer to
        keep.
      prefetch_buffer_size: If greater than 0, a buffer of this size is set up
        on each device to prefetch into. To prevent deadlocks, if
        prefetch_buffer_size is greater than max_buffer_size, then
        max_buffer_size is raised to prefetch_buffer_size.
      source_device: The host device to place the `dataset` on.
    """
    options = options_lib.Options()
    options.experimental_distribute.num_devices = len(devices)
    dataset = dataset.with_options(options)
    self._dataset = dataset._apply_debug_options()  # pylint: disable=protected-access
    self._experimental_slack = dataset.options().experimental_slack
    self._devices = devices
    self._source_device = source_device
    self._source_device_tensor = ops.convert_to_tensor(source_device)
    self._max_buffer_size = max_buffer_size
    self._prefetch_buffer_size = prefetch_buffer_size

    if self._prefetch_buffer_size > self._max_buffer_size:
      self._max_buffer_size = self._prefetch_buffer_size

    # Create the MultiDeviceIterator.
    with ops.device(self._source_device):
      # TODO(b/121378567): Get rid of this shared_name hack.
      shared_name = ""
      if context.executing_eagerly():
        shared_name = context.anonymous_name()
      self._multi_device_iterator_resource = (
          gen_dataset_ops.multi_device_iterator(
              devices=self._devices,
              shared_name=shared_name,
              container="",
              **self._dataset._flat_structure))  # pylint: disable=protected-access
      if context.executing_eagerly():
        # Delete the resource when this object is deleted.
        self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
            handle=self._multi_device_iterator_resource,
            handle_device=self._source_device)

      # The incarnation ID is used to ensure consistency between the per-device
      # iterators and the multi-device iterator.
      self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
          self._dataset._variant_tensor,  # pylint: disable=protected-access
          self._multi_device_iterator_resource,
          max_buffer_size=self._max_buffer_size)

    self._prototype_device_datasets = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _PerDeviceGenerator(
            i,
            self._multi_device_iterator_resource,
            self._incarnation_id,
            self._source_device_tensor,
            self._dataset.element_spec,
            iterator_is_anonymous=False)
        self._prototype_device_datasets.append(ds)

    # TODO(rohanj): Explore letting the MultiDeviceIterator initialize the
    # device side of the pipeline. This would allow it, for example, to move
    # some transformations from its input onto the device side. It might be
    # useful in rewriting.
    # Create the per-device iterators.
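    # In eager mode, a one-shot iterator per device is sufficient; in graph
    # mode, initializable iterators are created instead and their initializers
    # are grouped into the single `initializer` op below.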
    self._device_iterators = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _create_device_dataset(self._prototype_device_datasets[i],
                                    self._incarnation_id,
                                    self._prefetch_buffer_size,
                                    self._experimental_slack)
        if context.executing_eagerly():
          self._device_iterators.append(dataset_ops.make_one_shot_iterator(ds))
        else:
          self._device_iterators.append(
              dataset_ops.make_initializable_iterator(ds))

    if not context.executing_eagerly():
      device_iterator_initializers = [
          iterator.initializer for iterator in self._device_iterators
      ]
      self._initializer = control_flow_ops.group(*device_iterator_initializers)

  def get_next(self, device=None):
    """Returns the next element for `device`, or a list of all next elements."""
    if device is not None:
      index = self._devices.index(device)
      return self._device_iterators[index].get_next()

    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next())
    return result

  def get_next_as_optional(self):
    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next_as_optional())
    return result

  @property
  def initializer(self):
    if context.executing_eagerly():
      return control_flow_ops.no_op()
    return self._initializer

  def _eager_reset(self):
    """Resets the MultiDeviceIterator in eager mode."""
    if not ops.executing_eagerly_outside_functions():
      raise ValueError(
          "Resetting a multi-device iterator is only supported in eager "
          "mode.")
    # pylint: disable=protected-access
    self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
        self._dataset._variant_tensor,
        self._multi_device_iterator_resource,
        max_buffer_size=self._max_buffer_size)
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _create_device_dataset(self._prototype_device_datasets[i],
                                    self._incarnation_id,
                                    self._prefetch_buffer_size,
                                    self._experimental_slack)
        # Reset the device iterator resources with the new dataset.
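        # `make_iterator` re-initializes the existing per-device iterator
        # resource in place with the rebuilt dataset, so the iterator handles
        # held by this object remain valid across the reset.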
        ds_variant = ds._variant_tensor
        gen_dataset_ops.make_iterator(
            ds_variant, self._device_iterators[i]._iterator_resource)

  @property
  def element_spec(self):
    return self._dataset.element_spec


class MultiDeviceIteratorSpec(type_spec.TypeSpec):
  """Type specification for `OwnedMultiDeviceIterator`."""

  __slots__ = ["_devices", "_source_device", "_element_spec"]

  def __init__(self, devices, source_device, element_spec):
    self._devices = devices
    self._source_device = source_device
    self._element_spec = element_spec

  @property
  def value_type(self):
    return OwnedMultiDeviceIterator

  def _serialize(self):
    return (tuple(self._devices), self._source_device, self._element_spec)

  @property
  def _component_specs(self):
    specs = [
        tensor_spec.TensorSpec([], dtypes.resource),
    ]
    for _ in range(len(self._devices)):
      specs.append(iterator_ops.IteratorSpec(self._element_spec))
    return specs

  def _to_components(self, value):
    # pylint: disable=protected-access
    c = [value._multi_device_iterator_resource]
    c.extend(value._device_iterators)
    return c

  def _from_components(self, components):
    return OwnedMultiDeviceIterator(
        dataset=None,
        devices=self._devices,
        source_device=self._source_device,
        components=components,
        element_spec=self._element_spec)

  @staticmethod
  def from_value(value):
    # pylint: disable=protected-access
    return MultiDeviceIteratorSpec(
        value._devices,
        value._source_device,
        value.element_spec)


class OwnedMultiDeviceIterator(composite_tensor.CompositeTensor):
  """An iterator over multiple devices.

  The multi-device iterator resource created through `OwnedMultiDeviceIterator`
  is owned by the Python object, and the lifetime of the underlying resource is
  tied to the lifetime of the `OwnedMultiDeviceIterator` object. This makes
  `OwnedMultiDeviceIterator` appropriate for use in eager mode and inside of
  `tf.function`s.
  """

  def __init__(self,
               dataset=None,
               devices=None,
               max_buffer_size=1,
               prefetch_buffer_size=1,
               source_device="/cpu:0",
               components=None,
               element_spec=None):
    """Constructs an owned MultiDeviceIterator object.

    Args:
      dataset: The input dataset to be iterated over.
      devices: (Required.) The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host-side per-device buffer to
        keep.
      prefetch_buffer_size: If greater than 0, a buffer of this size is set up
        on each device to prefetch into. To prevent deadlocks, if
        prefetch_buffer_size is greater than max_buffer_size, then
        max_buffer_size is raised to prefetch_buffer_size.
      source_device: The host device to place the `dataset` on.
      components: Tensor components to construct the MultiDeviceIterator from.
      element_spec: A (nested) structure of `tf.TypeSpec` objects that
        represents the type specification of elements of the iterator.

    Raises:
      RuntimeError: If executed in graph mode or outside of function-building
        mode.
      ValueError: If any of the following happens:
        - `devices` is `None`.
        - `dataset` is `None` and either `components` or `element_spec` is
          `None`.
        - `dataset` is not `None` and either `components` or `element_spec` is
          provided.
    """
    if not context.executing_eagerly() and not ops.inside_function():
      raise RuntimeError("OwnedMultiDeviceIterator is only supported inside of "
                         "tf.function or when eager execution is enabled.")
    if devices is None:
      raise ValueError("`devices` must be provided.")

    if dataset is None:
      if components is None or element_spec is None:
        raise ValueError(
            "When `dataset` is not provided, both `components` and "
            "`element_spec` must be specified.")
      self._element_spec = element_spec
      self._devices = devices
      self._source_device = source_device
      self._multi_device_iterator_resource = components[0]
      self._device_iterators = components[1:]
    else:
      if components is not None or element_spec is not None:
        raise ValueError(
            "When `dataset` is provided, `element_spec` and `components` must "
            "not be specified.")
      options = options_lib.Options()
      options.experimental_distribute.num_devices = len(devices)
      dataset = dataset.with_options(options)
      dataset = dataset._apply_debug_options()  # pylint: disable=protected-access
      self._element_spec = dataset.element_spec
      experimental_slack = dataset.options().experimental_slack
      self._devices = devices
      self._source_device = source_device
      source_device_tensor = ops.convert_to_tensor(self._source_device)

      if prefetch_buffer_size > max_buffer_size:
        max_buffer_size = prefetch_buffer_size

      # Create the MultiDeviceIterator.
      with ops.device(self._source_device):
        self._multi_device_iterator_resource = (
            gen_dataset_ops.anonymous_multi_device_iterator_v3(
                devices=self._devices, **dataset._flat_structure))  # pylint: disable=protected-access

        # The incarnation ID is used to ensure consistency between the
        # per-device iterators and the multi-device iterator.
        incarnation_id = gen_dataset_ops.multi_device_iterator_init(
            dataset._variant_tensor,  # pylint: disable=protected-access
            self._multi_device_iterator_resource,
            max_buffer_size=max_buffer_size)

      prototype_device_datasets = []
      for i, device in enumerate(self._devices):
        with ops.device(device):
          ds = _PerDeviceGenerator(
              i,
              self._multi_device_iterator_resource,
              incarnation_id,
              source_device_tensor,
              dataset.element_spec,
              iterator_is_anonymous=True,
          )
          prototype_device_datasets.append(ds)

      # TODO(rohanj): Explore letting the MultiDeviceIterator initialize the
      # device side of the pipeline. This would allow it, for example, to move
      # some transformations from its input onto the device side. It might be
      # useful in rewriting.
      # Create the per-device iterators.
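      # `iter(ds)` below creates owned iterators, so each per-device
      # iterator's lifetime is tied to this Python object rather than to a
      # shared resource container.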
      self._device_iterators = []

      for i, device in enumerate(self._devices):
        with ops.device(device):
          ds = _create_device_dataset(prototype_device_datasets[i],
                                      incarnation_id, prefetch_buffer_size,
                                      experimental_slack)
          iterator = iter(ds)
          self._device_iterators.append(iterator)

  def get_next(self, device=None):
    """Returns the next element for `device`, or a list of all next elements."""
    if device is not None:
      index = self._devices.index(device)
      return self._device_iterators[index].get_next()

    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next())
    return result

  def __iter__(self):
    return self

  def next(self):
    return self.__next__()

  def __next__(self):
    try:
      return self.get_next()
    except errors.OutOfRangeError:
      raise StopIteration

  def get_next_as_optional(self):
    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next_as_optional())
    return result

  @property
  def element_spec(self):
    return self._element_spec

  @property
  def _type_spec(self):
    return MultiDeviceIteratorSpec(self._devices, self._source_device,
                                   self._element_spec)
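

# A minimal usage sketch of the classes above, assuming eager execution and
# two local GPUs. The device names and the toy dataset are illustrative
# assumptions, not part of this module:
#
#   import tensorflow as tf
#
#   dataset = tf.data.Dataset.range(8).batch(2)
#   mdi = OwnedMultiDeviceIterator(dataset, ["/gpu:0", "/gpu:1"])
#   for elements in mdi:
#     # `elements` is a list with one element per device, in `devices` order.
#     gpu0_batch, gpu1_batch = elements
#
#   # A single device can also be polled directly:
#   batch = mdi.get_next("/gpu:0")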