# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adapter module that converts different input data objects into `tf.data.Dataset`."""

import abc
import contextlib
import functools
import itertools
import math
import random

import numpy as np

from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import dataset_creator
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export

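
# NOTE: Illustrative sketch only; it is not used by Keras itself. It shows the
# flow this module implements for in-memory NumPy data: select the single
# adapter class that can handle the input, construct it, and iterate the
# resulting `tf.data.Dataset`. The function name and the example arrays below
# are hypothetical; `select_data_adapter` is defined later in this file and is
# resolved when the function is called.
def _example_numpy_adapter_flow():
  x = np.random.uniform(size=(64, 3)).astype("float32")
  y = np.random.uniform(size=(64, 1)).astype("float32")
  adapter_cls = select_data_adapter(x, y)  # Expected: TensorLikeDataAdapter.
  adapter = adapter_cls(x, y, batch_size=16)
  dataset = adapter.get_dataset()  # A `tf.data.Dataset` yielding 4 batches.
  for batch in dataset:
    del batch  # Training code would consume `(x_batch, y_batch)` here.
  return adapter.get_size()  # 4
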

class DataAdapter(object, metaclass=abc.ABCMeta):
  """Base class for input data adapters.

  In TF 2.0, tf.data is the preferred API for users to feed in data. In order
  to simplify the training code path, all input data objects will be
  converted to `tf.data.Dataset` if possible.

  Note that since this class is mainly targeted at TF 2.0, it might make a lot
  of assumptions under the hood, e.g. an eager context by default, distribution
  strategy, etc. In the meantime, some legacy feature support might be dropped,
  e.g. the Iterator from the v1 dataset API.

  Sample usage of this class:

  ```
  x = tf.data.Dataset.range(100)
  adapter_cls = [NumpyArrayDataAdapter, ..., DatasetAdapter]
  applicable_adapters = [cls for cls in adapter_cls if cls.can_handle(x)]
  if len(applicable_adapters) != 1:
    raise ValueError("Expect only one adapter class to handle the input")

  dataset = applicable_adapters[0](x).get_dataset()
  for data in dataset:
    # training
  ```
  """

  @staticmethod
  def can_handle(x, y=None):
    """Whether the current DataAdapter could handle the input x and y.

    Structure-wise, x and y can be a single object, a list of objects if there
    are multiple inputs/outputs, or a dictionary of objects when the
    inputs/outputs are named.

    Args:
      x: input features.
      y: target labels. Note that y could be None in the case of prediction.

    Returns:
      boolean
    """
    raise NotImplementedError

  @abc.abstractmethod
  def __init__(self, x, y=None, **kwargs):
    """Create a DataAdapter based on data inputs.

    The caller must make sure to call `can_handle()` first before invoking this
    method. Providing an unsupported data type will result in unexpected
    behavior.

    Args:
      x: input features.
      y: target labels. Note that y could be None in the case of prediction.
      **kwargs: Other keyword arguments for DataAdapter during the construction
        of the tf.data.Dataset. For example:
        - Numpy data might have `sample_weights` which will be used for
          weighting the loss function during training.
        - Numpy data might need a `batch_size` parameter when constructing
          the dataset and iterator.
        - Certain inputs might need to be distribution-strategy aware. When
          `distribution_strategy` is passed, the created dataset needs to
          respect the strategy.
        DataAdapter might choose to ignore any keyword argument if it doesn't
        use it, or raise an exception if any required argument is not provided.
    """
    if not self.can_handle(x, y):
      raise ValueError("{} Cannot handle input {}, {}".format(
          self.__class__, x, y))

  @abc.abstractmethod
  def get_dataset(self):
    """Get a dataset instance for the current DataAdapter.

    Note that the dataset returned does not repeat for epochs, so the caller
    might need to create a new iterator for the same dataset at the beginning
    of each epoch. This behavior might change in the future.

    Returns:
      A tf.data.Dataset. The caller might use the dataset in different
      contexts, e.g. iter(dataset) in eager mode to get the values directly,
      or in graph mode, provide the iterator tensor to the Keras model
      function.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def get_size(self):
    """Return the size (number of batches) for the dataset created.

    For certain types of data input, the number of batches is known, e.g. for
    Numpy data, the size is the same as (number_of_element / batch_size).
    Whereas for a dataset or a python generator, the size is unknown since it
    may or may not have an end state.

    Returns:
      int, the number of batches for the dataset, or None if it is unknown. The
      caller could use this to control the training loop, show a progress bar,
      or handle an unexpected StopIteration error.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def batch_size(self):
    """Return the batch size of the dataset created.

    For certain types of data input, the batch size is known, and even
    required, like numpy arrays. Whereas for a dataset, the batch size is
    unknown unless we take a peek.

    Returns:
      int, the batch size of the dataset, or None if it is unknown.
    """
    raise NotImplementedError

  def representative_batch_size(self):
    """Return a representative size for batches in the dataset.

    This is not guaranteed to be the batch size for all batches in the
    dataset. It just needs to be a rough approximation for batch sizes in
    the dataset.

    Returns:
      int, a representative size for batches found in the dataset,
      or None if it is unknown.
177 """ 178 return self.batch_size() 179 180 @abc.abstractmethod 181 def has_partial_batch(self): 182 """Whether the dataset has partial batch at the end.""" 183 raise NotImplementedError 184 185 @abc.abstractmethod 186 def partial_batch_size(self): 187 """The size of the final partial batch for dataset. 188 189 Will return None if has_partial_batch is False or batch_size is None. 190 """ 191 raise NotImplementedError 192 193 @abc.abstractmethod 194 def should_recreate_iterator(self): 195 """Returns whether a new iterator should be created every epoch.""" 196 raise NotImplementedError 197 198 def get_samples(self): 199 """Returns number of samples in the data, or `None`.""" 200 if not self.get_size() or not self.batch_size(): 201 return None 202 total_sample = self.get_size() * self.batch_size() 203 if self.has_partial_batch(): 204 total_sample -= (self.batch_size() - self.partial_batch_size()) 205 return total_sample 206 207 def on_epoch_end(self): 208 """A hook called after each epoch.""" 209 pass 210 211 212class TensorLikeDataAdapter(DataAdapter): 213 """Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy.""" 214 215 @staticmethod 216 def can_handle(x, y=None): 217 # TODO(kaftan): Check performance implications of using a flatten 218 # here for other types of inputs. 219 flat_inputs = nest.flatten(x) 220 if y is not None: 221 flat_inputs += nest.flatten(y) 222 223 tensor_types = _get_tensor_types() 224 225 def _is_tensor(v): 226 if isinstance(v, tensor_types): 227 return True 228 return False 229 230 return all(_is_tensor(v) for v in flat_inputs) 231 232 def __init__(self, 233 x, 234 y=None, 235 sample_weights=None, 236 sample_weight_modes=None, 237 batch_size=None, 238 epochs=1, 239 steps=None, 240 shuffle=False, 241 **kwargs): 242 super(TensorLikeDataAdapter, self).__init__(x, y, **kwargs) 243 x, y, sample_weights = _process_tensorlike((x, y, sample_weights)) 244 sample_weight_modes = broadcast_sample_weight_modes( 245 sample_weights, sample_weight_modes) 246 247 # If sample_weights are not specified for an output use 1.0 as weights. 248 (sample_weights, _, _) = training_utils.handle_partial_sample_weights( 249 y, sample_weights, sample_weight_modes, check_all_flat=True) 250 251 inputs = pack_x_y_sample_weight(x, y, sample_weights) 252 253 num_samples = set(int(i.shape[0]) for i in nest.flatten(inputs)).pop() 254 _check_data_cardinality(inputs) 255 256 # If batch_size is not passed but steps is, calculate from the input data. 257 # Default to 32 for backwards compat. 258 if not batch_size: 259 batch_size = int(math.ceil(num_samples / steps)) if steps else 32 260 261 self._size = int(math.ceil(num_samples / batch_size)) 262 self._batch_size = batch_size 263 264 num_full_batches = int(num_samples // batch_size) 265 self._partial_batch_size = num_samples % batch_size 266 267 if isinstance(shuffle, str): 268 shuffle = shuffle.lower() 269 270 self._shuffle = shuffle 271 # Vectorized version of shuffle. 272 # This is a performance improvement over using `from_tensor_slices`. 273 # The indices of the data are shuffled and batched, and these indices 274 # are then zipped with the data and used to extract a batch of the data 275 # at each step. The performance improvements here come from: 276 # 1. vectorized batch using gather 277 # 2. parallelized map 278 # 3. pipelined permutation generation 279 # 4. optimized permutation batching 280 # 5. 
disabled static optimizations 281 282 indices_dataset = dataset_ops.DatasetV2.range(1) 283 if shuffle != "batch": 284 indices_dataset = indices_dataset.repeat(epochs) 285 286 def permutation(_): 287 # It turns out to be more performant to make a new set of indices rather 288 # than reusing the same range Tensor. (presumably because of buffer 289 # forwarding.) 290 indices = math_ops.range(num_samples, dtype=dtypes.int64) 291 if shuffle and shuffle != "batch": 292 indices = random_ops.random_shuffle(indices) 293 return indices 294 295 # We prefetch a single element. Computing large permutations can take quite 296 # a while so we don't want to wait for prefetching over an epoch boundary to 297 # trigger the next permutation. On the other hand, too many simultaneous 298 # shuffles can contend on a hardware level and degrade all performance. 299 indices_dataset = indices_dataset.map(permutation).prefetch(1) 300 301 def slice_batch_indices(indices): 302 """Convert a Tensor of indices into a dataset of batched indices. 303 304 This step can be accomplished in several ways. The most natural is to 305 slice the Tensor in a Dataset map. (With a condition on the upper index to 306 handle the partial batch.) However it turns out that coercing the Tensor 307 into a shape which is divisible by the batch size (and handling the last 308 partial batch separately) allows for a much more favorable memory access 309 pattern and improved performance. 310 311 Args: 312 indices: Tensor which determines the data order for an entire epoch. 313 314 Returns: 315 A Dataset of batched indices. 316 """ 317 num_in_full_batch = num_full_batches * batch_size 318 first_k_indices = array_ops.slice(indices, [0], [num_in_full_batch]) 319 first_k_indices = array_ops.reshape( 320 first_k_indices, [num_full_batches, batch_size]) 321 322 flat_dataset = dataset_ops.DatasetV2.from_tensor_slices(first_k_indices) 323 if self._partial_batch_size: 324 index_remainder = dataset_ops.DatasetV2.from_tensors(array_ops.slice( 325 indices, [num_in_full_batch], [self._partial_batch_size])) 326 flat_dataset = flat_dataset.concatenate(index_remainder) 327 328 if shuffle == "batch": 329 # 1024 is a magic constant that has not been properly evaluated 330 flat_dataset = flat_dataset.shuffle(1024).repeat(epochs) 331 return flat_dataset 332 333 indices_dataset = indices_dataset.flat_map(slice_batch_indices) 334 335 dataset = self.slice_inputs(indices_dataset, inputs) 336 337 if shuffle == "batch": 338 def shuffle_batch(*batch): 339 return nest.map_structure(random_ops.random_shuffle, batch) 340 dataset = dataset.map(shuffle_batch) 341 342 self._dataset = dataset 343 344 def slice_inputs(self, indices_dataset, inputs): 345 """Slice inputs into a Dataset of batches. 346 347 Given a Dataset of batch indices and the unsliced inputs, 348 this step slices the inputs in a parallelized fashion 349 and produces a dataset of input batches. 350 351 Args: 352 indices_dataset: A Dataset of batched indices 353 inputs: A python data structure that contains the inputs, targets, 354 and possibly sample weights. 355 356 Returns: 357 A Dataset of input batches matching the batch indices. 
358 """ 359 dataset = dataset_ops.DatasetV2.zip(( 360 indices_dataset, 361 dataset_ops.DatasetV2.from_tensors(inputs).repeat() 362 )) 363 364 def grab_batch(i, data): 365 return nest.map_structure(lambda d: array_ops.gather(d, i, axis=0), data) 366 367 dataset = dataset.map( 368 grab_batch, num_parallel_calls=dataset_ops.AUTOTUNE) 369 370 # Default optimizations are disabled to avoid the overhead of (unnecessary) 371 # input pipeline graph serialization and deserialization 372 options = options_lib.Options() 373 options.experimental_optimization.apply_default_optimizations = False 374 if self._shuffle: 375 # See b/141490660 for more details. 376 options.experimental_external_state_policy = ( 377 options_lib.ExternalStatePolicy.IGNORE) 378 dataset = dataset.with_options(options) 379 return dataset 380 381 def get_dataset(self): 382 return self._dataset 383 384 def get_size(self): 385 return self._size 386 387 def batch_size(self): 388 return self._batch_size 389 390 def has_partial_batch(self): 391 return self._partial_batch_size > 0 392 393 def partial_batch_size(self): 394 return self._partial_batch_size or None 395 396 def should_recreate_iterator(self): 397 # An infinite dataset is always created here. 398 return False 399 400 401class GenericArrayLikeDataAdapter(TensorLikeDataAdapter): 402 """Adapter that handles array-like data without forcing it into memory. 403 404 This adapter handles array-like datasets that may be too big to fully 405 fit into memory. 406 407 Specifically, this adapter handles any Python class which implements: 408 `__get_item__`, `__len__`, `shape`, and `dtype` with the same meanings 409 as Numpy, but it ignores any case where all the inputs are Tensors or Numpy 410 arrays (because that case is handled by the base TensorLikeDataAdapter). 411 412 It ignores scipy sparse matrices and Composite Tensors because those are 413 handled by the CompositeTensorDataAdapter. 414 415 It also does not handle lists/tuples of scalars, because those are handled 416 by the ListsOfScalarsDataAdapter. 417 """ 418 419 @staticmethod 420 def can_handle(x, y=None): 421 flat_inputs = nest.flatten(x) 422 if y is not None: 423 flat_inputs += nest.flatten(y) 424 425 def _is_array_like(v): 426 """Return True if v is a Tensor, array, or is array-like.""" 427 return ( 428 hasattr(v, "__getitem__") and 429 hasattr(v, "shape") and 430 hasattr(v, "dtype") and 431 hasattr(v, "__len__") 432 ) 433 434 if (not TensorLikeDataAdapter.can_handle(x, y) and 435 not CompositeTensorDataAdapter.can_handle(x, y)): 436 return all(_is_array_like(v) for v in flat_inputs) 437 else: 438 return False 439 440 def __init__(self, *args, **kwargs): 441 logging.warning( 442 "Keras is training/fitting/evaluating on array-like data. Keras may " 443 "not be optimized for this format, so if your input data format is " 444 "supported by TensorFlow I/O (https://github.com/tensorflow/io) we " 445 "recommend using that to load a Dataset instead.") 446 447 super(GenericArrayLikeDataAdapter, self).__init__(*args, **kwargs) 448 449 def slice_inputs(self, indices_dataset, inputs): 450 """Slice inputs into a Dataset of batches. 451 452 Given a Dataset of batch indices and the unsliced inputs, 453 this step slices the inputs in a parallelized fashion 454 and produces a dataset of input batches. 455 456 Args: 457 indices_dataset: A Dataset of batched indices 458 inputs: A python data structure that contains the inputs, targets, 459 and possibly sample weights. 
460 461 Returns: 462 A Dataset of input batches matching the batch indices. 463 """ 464 flat_inputs = nest.flatten(inputs) 465 def dynamic_shape_like(t): 466 shape = list(t.shape) 467 shape[0] = None 468 return tuple(shape) 469 470 flat_dtypes = [inp.dtype for inp in flat_inputs] 471 contiguous = True 472 if self._shuffle and self._shuffle != "batch": 473 contiguous = False 474 475 def grab_batch(indices): 476 """Grab a batch of data from the inputs.""" 477 # This uses a py_function to avoid converting the array-like 478 # into a Tensor before slicing it, because converting the array-like 479 # to a Tensor may force it into memory.. 480 def py_method(ind): 481 def slice_array(data): 482 return training_utils.slice_arrays(data, ind.numpy(), 483 contiguous=contiguous) 484 return [slice_array(inp) for inp in flat_inputs] 485 486 flat_out = script_ops.eager_py_func(py_method, [indices], flat_dtypes) 487 for v, original_inp in zip(flat_out, flat_inputs): 488 v.set_shape(dynamic_shape_like(original_inp)) 489 return nest.pack_sequence_as(inputs, flat_out) 490 491 dataset = indices_dataset.map( 492 grab_batch, num_parallel_calls=dataset_ops.AUTOTUNE) 493 494 return dataset 495 496 497class DatasetCreatorAdapter(DataAdapter): 498 """Adapter that handles dataset functions.""" 499 500 def __init__(self, x, y, steps=None, distribution_strategy=None, **kwargs): 501 super(DatasetCreatorAdapter, self).__init__(x, **kwargs) 502 503 if not isinstance(x, dataset_creator.DatasetCreator): 504 raise TypeError("The input of a `DatasetCreatorAdapter` should be a " 505 "`DatasetCreator` but it received type {}.".format( 506 type(x))) 507 if steps is None: 508 raise ValueError("When using a " 509 "`tf.keras.utils.experimental.DatasetCreator`, " 510 "`steps_per_epoch`, `validation_steps` or `steps` " 511 "argument must be provided in `Model.fit`, " 512 "`Model.evaluate`, or `Model.predict`.") 513 self.dataset_creator = x 514 self.steps = steps 515 self.strategy = distribution_strategy 516 517 @staticmethod 518 def can_handle(x, y=None): 519 if isinstance(x, dataset_creator.DatasetCreator): 520 assert y is None 521 return True 522 523 def should_recreate_iterator(self): 524 # We expect users to shuffle the dataset in their `dataset_fn` supplied to 525 # `DatasetCreator`. Since that is a buffered shuffle, we intend to not reset 526 # the dataset so the batches that are not shuffled can still be pulled. 527 return False 528 529 def get_size(self): 530 return None # To be inferred by `DataHandler`. 531 532 def get_dataset(self): 533 return self.strategy.distribute_datasets_from_function( 534 self.dataset_creator, options=self.dataset_creator.input_options) 535 536 def batch_size(self): 537 raise NotImplementedError() 538 539 def has_partial_batch(self): 540 raise NotImplementedError() 541 542 def partial_batch_size(self): 543 raise NotImplementedError() 544 545 546class CompositeTensorDataAdapter(DataAdapter): 547 """Adapter that handles composite tensor.""" 548 549 @staticmethod 550 def can_handle(x, y=None): 551 flat_inputs = nest.flatten(x) 552 if y is not None: 553 flat_inputs += nest.flatten(y) 554 555 def _is_composite(v): 556 # Dataset/iterator/DistributedDataset inherits from CompositeTensor but 557 # should be handled by DatasetAdapter and GeneratorAdapter. 
558 if (tf_utils.is_extension_type(v) and 559 not isinstance(v, 560 (dataset_ops.DatasetV2, iterator_ops.IteratorBase)) and 561 not _is_distributed_dataset(v)): 562 return True 563 # Support Scipy sparse tensors if scipy is installed 564 return _is_scipy_sparse(v) 565 566 def _is_tensor_or_composite(v): 567 if isinstance(v, (ops.Tensor, np.ndarray)): 568 return True 569 return _is_composite(v) 570 571 return (any(_is_composite(v) for v in flat_inputs) and 572 all(_is_tensor_or_composite(v) for v in flat_inputs)) 573 574 def __init__(self, 575 x, 576 y=None, 577 sample_weights=None, 578 sample_weight_modes=None, 579 batch_size=None, 580 steps=None, 581 shuffle=False, 582 **kwargs): 583 super(CompositeTensorDataAdapter, self).__init__(x, y, **kwargs) 584 x, y, sample_weights = _process_tensorlike((x, y, sample_weights)) 585 sample_weight_modes = broadcast_sample_weight_modes( 586 sample_weights, sample_weight_modes) 587 588 # If sample_weights are not specified for an output use 1.0 as weights. 589 (sample_weights, _, _) = training_utils.handle_partial_sample_weights( 590 y, sample_weights, sample_weight_modes, check_all_flat=True) 591 592 inputs = pack_x_y_sample_weight(x, y, sample_weights) 593 594 dataset = dataset_ops.DatasetV2.from_tensor_slices(inputs) 595 num_samples = int(nest.flatten(x)[0].shape[0]) 596 if shuffle: 597 dataset = dataset.shuffle(num_samples) 598 599 # If batch_size is not passed but steps is, calculate from the input data. 600 # Default to 32 for backwards compat. 601 if not batch_size: 602 batch_size = int(math.ceil(num_samples / steps)) if steps else 32 603 604 dataset = dataset.batch(batch_size) 605 self._size = int(math.ceil(num_samples / batch_size)) 606 self._batch_size = batch_size 607 self._has_partial_batch = (self._size != (num_samples // batch_size)) 608 609 self._partial_batch_size = None 610 if self._has_partial_batch: 611 self._partial_batch_size = ( 612 num_samples - (self._size - 1) * self._batch_size) 613 614 self._dataset = dataset 615 616 def get_dataset(self): 617 return self._dataset 618 619 def get_size(self): 620 return self._size 621 622 def batch_size(self): 623 return self._batch_size 624 625 def has_partial_batch(self): 626 return self._has_partial_batch 627 628 def partial_batch_size(self): 629 return self._partial_batch_size 630 631 def should_recreate_iterator(self): 632 return True 633 634 635class ListsOfScalarsDataAdapter(DataAdapter): 636 """Adapter that handles lists of scalars and lists of lists of scalars.""" 637 638 @staticmethod 639 def can_handle(x, y=None): 640 handles_x = ListsOfScalarsDataAdapter._is_list_of_scalars(x) 641 handles_y = True 642 if y is not None: 643 handles_y = ListsOfScalarsDataAdapter._is_list_of_scalars(y) 644 return handles_x and handles_y 645 646 @staticmethod 647 def _is_list_of_scalars(inp): 648 if isinstance(inp, (float, int, str, bytes, bytearray)): 649 return True 650 if isinstance(inp, (list, tuple)) and inp: 651 return ListsOfScalarsDataAdapter._is_list_of_scalars(inp[0]) 652 return False 653 654 def __init__(self, 655 x, 656 y=None, 657 sample_weights=None, 658 sample_weight_modes=None, 659 batch_size=None, 660 shuffle=False, 661 **kwargs): 662 super(ListsOfScalarsDataAdapter, self).__init__(x, y, **kwargs) 663 x = np.asarray(x) 664 if y is not None: 665 y = np.asarray(y) 666 if sample_weights is not None: 667 sample_weights = np.asarray(sample_weights) 668 sample_weight_modes = broadcast_sample_weight_modes( 669 sample_weights, sample_weight_modes) 670 671 self._internal_adapter = 
TensorLikeDataAdapter( 672 x, 673 y=y, 674 sample_weights=sample_weights, 675 sample_weight_modes=sample_weight_modes, 676 batch_size=batch_size, 677 shuffle=shuffle, 678 **kwargs) 679 680 def get_dataset(self): 681 return self._internal_adapter.get_dataset() 682 683 def get_size(self): 684 return self._internal_adapter.get_size() 685 686 def batch_size(self): 687 return self._internal_adapter.batch_size() 688 689 def has_partial_batch(self): 690 return self._internal_adapter.has_partial_batch() 691 692 def partial_batch_size(self): 693 return self._internal_adapter.partial_batch_size() 694 695 def should_recreate_iterator(self): 696 return True 697 698 699class DatasetAdapter(DataAdapter): 700 """Adapter that handles `tf.data.Dataset`.""" 701 702 @staticmethod 703 def can_handle(x, y=None): 704 return (isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)) or 705 _is_distributed_dataset(x)) 706 707 def __init__(self, 708 x, 709 y=None, 710 sample_weights=None, 711 steps=None, 712 **kwargs): 713 super(DatasetAdapter, self).__init__(x, y, **kwargs) 714 # Note that the dataset instance is immutable, its fine to reuse the user 715 # provided dataset. 716 self._dataset = x 717 718 # The user-provided steps. 719 self._user_steps = steps 720 721 self._validate_args(y, sample_weights, steps) 722 723 def get_dataset(self): 724 return self._dataset 725 726 def get_size(self): 727 return # Inferred in `DataHandler`. 728 729 def batch_size(self): 730 return None 731 732 def has_partial_batch(self): 733 return False 734 735 def partial_batch_size(self): 736 return None 737 738 def should_recreate_iterator(self): 739 # Since DistributedDatasets have no cardinality, the user must provide 740 # all steps that need to be run, calling `.repeat()` as needed. 741 if _is_distributed_dataset(self._dataset): 742 return False 743 744 # If user doesn't supply `steps`, or if they supply `steps` that 745 # exactly equals the size of the `Dataset`, create a new iterator 746 # each epoch. 747 return (self._user_steps is None or 748 cardinality.cardinality(self._dataset).numpy() == self._user_steps) 749 750 def _validate_args(self, y, sample_weights, steps): 751 """Validates `__init__` arguments.""" 752 # Arguments that shouldn't be passed. 
753 if not is_none_or_empty(y): 754 raise ValueError("`y` argument is not supported when using " 755 "dataset as input.") 756 if not is_none_or_empty(sample_weights): 757 raise ValueError("`sample_weight` argument is not supported when using " 758 "dataset as input.") 759 760 if steps is None: 761 if _is_distributed_dataset(self._dataset): 762 raise ValueError("When providing a distributed dataset, you must " 763 "specify the number of steps to run.") 764 765 size = cardinality.cardinality(self._dataset).numpy() 766 if size == cardinality.INFINITE and steps is None: 767 raise ValueError( 768 "When providing an infinite dataset, you must specify " 769 "the number of steps to run (if you did not intend to " 770 "create an infinite dataset, make sure to not call " 771 "`repeat()` on the dataset).") 772 773 774class GeneratorDataAdapter(DataAdapter): 775 """Adapter that handles python generators and iterators.""" 776 777 @staticmethod 778 def can_handle(x, y=None): 779 return ((hasattr(x, "__next__") or hasattr(x, "next")) 780 and hasattr(x, "__iter__") 781 and not isinstance(x, data_utils.Sequence)) 782 783 def __init__(self, 784 x, 785 y=None, 786 sample_weights=None, 787 workers=1, 788 use_multiprocessing=False, 789 max_queue_size=10, 790 model=None, 791 **kwargs): 792 # Generators should never shuffle as exhausting the generator in order to 793 # shuffle the batches is inefficient. 794 kwargs.pop("shuffle", None) 795 796 if not is_none_or_empty(y): 797 raise ValueError("`y` argument is not supported when using " 798 "python generator as input.") 799 if not is_none_or_empty(sample_weights): 800 raise ValueError("`sample_weight` argument is not supported when using " 801 "python generator as input.") 802 803 super(GeneratorDataAdapter, self).__init__(x, y, **kwargs) 804 805 # Since we have to know the dtype of the python generator when we build the 806 # dataset, we have to look at a batch to infer the structure. 807 peek, x = self._peek_and_restore(x) 808 peek = self._standardize_batch(peek) 809 peek = _process_tensorlike(peek) 810 811 # Need to build the Model on concrete input shapes. 812 if model is not None and not model.built: 813 concrete_x, _, _ = unpack_x_y_sample_weight(peek) 814 model.distribute_strategy.run( 815 lambda x: model(x, training=False), args=(concrete_x,)) 816 817 self._first_batch_size = int(nest.flatten(peek)[0].shape[0]) 818 819 def _get_dynamic_shape(t): 820 shape = t.shape 821 # Unknown number of dimensions, `as_list` cannot be called. 822 if shape.rank is None: 823 return shape 824 return tensor_shape.TensorShape([None for _ in shape.as_list()]) 825 826 output_shapes = nest.map_structure(_get_dynamic_shape, peek) 827 output_types = nest.map_structure(lambda t: t.dtype, peek) 828 829 # Note that dataset API takes a callable that creates a generator object, 830 # rather than generator itself, which is why we define a function here. 831 generator_fn = self._handle_multiprocessing(x, workers, use_multiprocessing, 832 max_queue_size) 833 834 def wrapped_generator(): 835 for data in generator_fn(): 836 yield self._standardize_batch(data) 837 838 dataset = dataset_ops.DatasetV2.from_generator( 839 wrapped_generator, output_types, output_shapes=output_shapes) 840 841 if workers == 1 and not use_multiprocessing: 842 dataset = dataset.prefetch(1) 843 844 self._dataset = dataset 845 846 def _standardize_batch(self, data): 847 """Standardizes a batch output by a generator.""" 848 # Removes `None`s. 
849 x, y, sample_weight = unpack_x_y_sample_weight(data) 850 data = pack_x_y_sample_weight(x, y, sample_weight) 851 852 data = nest.list_to_tuple(data) 853 854 def _convert_dtype(t): 855 if (isinstance(t, np.ndarray) and issubclass(t.dtype.type, np.floating)): 856 return np.array(t, dtype=backend.floatx()) 857 return t 858 859 data = nest.map_structure(_convert_dtype, data) 860 return data 861 862 @staticmethod 863 def _peek_and_restore(x): 864 peek = next(x) 865 return peek, itertools.chain([peek], x) 866 867 def _handle_multiprocessing(self, x, workers, use_multiprocessing, 868 max_queue_size): 869 """Create a callable, possibly including an Enqueuer.""" 870 if workers > 1 or (workers > 0 and use_multiprocessing): 871 def generator_fn(): 872 enqueuer = data_utils.GeneratorEnqueuer( 873 x, use_multiprocessing=use_multiprocessing) 874 enqueuer.start(workers=workers, max_queue_size=max_queue_size) 875 return enqueuer.get() 876 else: 877 generator_fn = lambda: x 878 return generator_fn 879 880 def get_dataset(self): 881 return self._dataset 882 883 def get_size(self): 884 return None 885 886 def batch_size(self): 887 return None 888 889 def representative_batch_size(self): 890 return self._first_batch_size 891 892 def has_partial_batch(self): 893 return False 894 895 def partial_batch_size(self): 896 return 897 898 def should_recreate_iterator(self): 899 return False 900 901 902class KerasSequenceAdapter(GeneratorDataAdapter): 903 """Adapter that handles `keras.utils.Sequence`.""" 904 905 @staticmethod 906 def can_handle(x, y=None): 907 return isinstance(x, data_utils.Sequence) 908 909 def __init__(self, 910 x, 911 y=None, 912 sample_weights=None, 913 shuffle=False, 914 workers=1, 915 use_multiprocessing=False, 916 max_queue_size=10, 917 model=None, 918 **kwargs): 919 if not is_none_or_empty(y): 920 raise ValueError("`y` argument is not supported when using " 921 "`keras.utils.Sequence` as input.") 922 if not is_none_or_empty(sample_weights): 923 raise ValueError("`sample_weight` argument is not supported when using " 924 "`keras.utils.Sequence` as input.") 925 926 self._size = len(x) 927 self._shuffle_sequence = shuffle 928 self._keras_sequence = x 929 self._enqueuer = None 930 super(KerasSequenceAdapter, self).__init__( 931 x, 932 shuffle=False, # Shuffle is handed in the _make_callable override. 933 workers=workers, 934 use_multiprocessing=use_multiprocessing, 935 max_queue_size=max_queue_size, 936 model=model, 937 **kwargs) 938 939 @staticmethod 940 def _peek_and_restore(x): 941 return x[0], x 942 943 def _handle_multiprocessing(self, x, workers, use_multiprocessing, 944 max_queue_size): 945 if workers > 1 or (workers > 0 and use_multiprocessing): 946 def generator_fn(): 947 self._enqueuer = data_utils.OrderedEnqueuer( 948 x, use_multiprocessing=use_multiprocessing, 949 shuffle=self._shuffle_sequence) 950 self._enqueuer.start(workers=workers, max_queue_size=max_queue_size) 951 return self._enqueuer.get() 952 else: 953 def generator_fn(): 954 order = range(len(x)) 955 if self._shuffle_sequence: 956 # Match the shuffle convention in OrderedEnqueuer. 
          order = list(order)
          random.shuffle(order)

        for i in order:
          yield x[i]

      return generator_fn

  def get_size(self):
    return self._size

  def should_recreate_iterator(self):
    return True

  def on_epoch_end(self):
    if self._enqueuer:
      self._enqueuer.stop()
    self._keras_sequence.on_epoch_end()


ALL_ADAPTER_CLS = [
    ListsOfScalarsDataAdapter, TensorLikeDataAdapter,
    GenericArrayLikeDataAdapter, DatasetAdapter, GeneratorDataAdapter,
    KerasSequenceAdapter, CompositeTensorDataAdapter, DatasetCreatorAdapter
]


def select_data_adapter(x, y):
  """Selects a data adapter that can handle a given x and y."""
  adapter_cls = [cls for cls in ALL_ADAPTER_CLS if cls.can_handle(x, y)]
  if not adapter_cls:
    # TODO(scottzhu): This should be a less implementation-specific error.
    raise ValueError(
        "Failed to find data adapter that can handle "
        "input: {}, {}".format(
            _type_name(x), _type_name(y)))
  elif len(adapter_cls) > 1:
    raise RuntimeError(
        "Data adapters should be mutually exclusive for "
        "handling inputs. Found multiple adapters {} to handle "
        "input: {}, {}".format(
            adapter_cls, _type_name(x), _type_name(y)))
  return adapter_cls[0]


def _type_name(x):
  """Generates a description of the type of an object."""
  if isinstance(x, dict):
    key_types = set(_type_name(key) for key in x.keys())
    val_types = set(_type_name(key) for key in x.values())
    return "({} containing {} keys and {} values)".format(
        type(x), key_types, val_types)
  if isinstance(x, (list, tuple)):
    types = set(_type_name(val) for val in x)
    return "({} containing values of types {})".format(
        type(x), types)
  return str(type(x))


def _process_tensorlike(inputs):
  """Process tensor-like inputs.

  This function:

  (1) Converts `Numpy` arrays to `Tensor`s.
  (2) Converts `Scipy` sparse matrices to `SparseTensor`s.
  (3) Converts `list`s to `tuple`s (for `tf.data` support).

  Args:
    inputs: Structure of `Tensor`s, `NumPy` arrays, or tensor-like.

  Returns:
    Structure of `Tensor`s or tensor-like.
  """

  def _convert_numpy_and_scipy(x):
    if isinstance(x, np.ndarray):
      dtype = None
      if issubclass(x.dtype.type, np.floating):
        dtype = backend.floatx()
      return ops.convert_to_tensor_v2_with_dispatch(x, dtype=dtype)
    elif _is_scipy_sparse(x):
      return _scipy_sparse_to_sparse_tensor(x)
    return x

  inputs = nest.map_structure(_convert_numpy_and_scipy, inputs)
  return nest.list_to_tuple(inputs)
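
# NOTE: Illustrative sketch only; not used by Keras itself. It shows what
# `_process_tensorlike` above does to a small, hypothetical structure: float
# NumPy arrays are converted to tensors of dtype `backend.floatx()`, integer
# arrays keep their dtype, and the outer `list` becomes a `tuple` so the
# structure can be consumed by `tf.data`.
def _example_process_tensorlike():
  features = np.zeros((8, 2), dtype=np.float64)
  labels = np.zeros((8, 1), dtype=np.int64)
  x, y = _process_tensorlike([features, labels])
  # `x.dtype` is float32 (assuming the default floatx), `y.dtype` stays int64,
  # and the returned structure is a tuple rather than a list.
  return x, y
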

def is_none_or_empty(inputs):
  # Util method to check if the input is None or an empty list. The plain
  # python "not" check would raise an error like the one below if the input
  # is a numpy array:
  # "The truth value of an array with more than one element is ambiguous.
  # Use a.any() or a.all()"
  return inputs is None or not nest.flatten(inputs)


def broadcast_sample_weight_modes(target_structure, sample_weight_modes):
  """Match sample_weight_modes structure with output structure."""
  if target_structure is None or not nest.flatten(target_structure):
    return sample_weight_modes

  if isinstance(sample_weight_modes, str):
    if isinstance(target_structure, dict):
      return {key: sample_weight_modes for key in target_structure.keys()}
    return [sample_weight_modes for _ in target_structure]

  if sample_weight_modes:
    try:
      nest.assert_same_structure(
          training_utils.list_to_tuple(target_structure),
          training_utils.list_to_tuple(sample_weight_modes))
    except (ValueError, TypeError):
      target_str = str(nest.map_structure(lambda _: "...", target_structure))
      mode_str = str(nest.map_structure(lambda _: "...", sample_weight_modes))

      # Attempt to coerce sample_weight_modes to the target structure. This
      # implicitly depends on the fact that Model flattens outputs for its
      # internal representation.
      try:
        sample_weight_modes = nest.pack_sequence_as(
            target_structure, nest.flatten(sample_weight_modes))
        logging.warning(
            "sample_weight modes were coerced from\n  {}\n    to  \n  {}"
            .format(target_str, mode_str))
      except (ValueError, TypeError):
        raise ValueError(
            "Unable to match target structure and sample_weight_modes "
            "structure:\n  {}\n    to  \n  {}".format(target_str, mode_str))

  return sample_weight_modes


class DataHandler(object):
  """Handles iterating over epoch-level `tf.data.Iterator` objects."""

  def __init__(self,
               x,
               y=None,
               sample_weight=None,
               batch_size=None,
               steps_per_epoch=None,
               initial_epoch=0,
               epochs=1,
               shuffle=False,
               class_weight=None,
               max_queue_size=10,
               workers=1,
               use_multiprocessing=False,
               model=None,
               steps_per_execution=None,
               distribute=True):
    """Initializes a `DataHandler`.

    Arguments:
      x: See `Model.fit`.
      y: See `Model.fit`.
      sample_weight: See `Model.fit`.
      batch_size: See `Model.fit`.
      steps_per_epoch: See `Model.fit`.
      initial_epoch: See `Model.fit`.
      epochs: See `Model.fit`.
      shuffle: See `Model.fit`.
      class_weight: See `Model.fit`.
      max_queue_size: See `Model.fit`.
      workers: See `Model.fit`.
      use_multiprocessing: See `Model.fit`.
      model: The `Model` instance. Needed in order to correctly `build` the
        `Model` using generator-like inputs (see `GeneratorDataAdapter`).
      steps_per_execution: See `Model.compile`.
      distribute: Whether to distribute the `tf.dataset`.
        `PreprocessingLayer.adapt` does not support distributed datasets;
        `Model` should always set this to `True`.
    """

    self._initial_epoch = initial_epoch
    self._epochs = epochs
    self._insufficient_data = False
    self._model = model

    # `steps_per_execution_value` is the cached initial value.
    # `steps_per_execution` is mutable and may be changed by the DataAdapter
    # to handle partial executions.
1141 if steps_per_execution is None: 1142 self._steps_per_execution = 1 1143 self._steps_per_execution_value = 1 1144 else: 1145 self._steps_per_execution = steps_per_execution 1146 self._steps_per_execution_value = steps_per_execution.numpy().item() 1147 1148 adapter_cls = select_data_adapter(x, y) 1149 self._adapter = adapter_cls( 1150 x, 1151 y, 1152 batch_size=batch_size, 1153 steps=steps_per_epoch, 1154 epochs=epochs - initial_epoch, 1155 sample_weights=sample_weight, 1156 shuffle=shuffle, 1157 max_queue_size=max_queue_size, 1158 workers=workers, 1159 use_multiprocessing=use_multiprocessing, 1160 distribution_strategy=ds_context.get_strategy(), 1161 model=model) 1162 1163 strategy = ds_context.get_strategy() 1164 1165 self._current_step = 0 1166 self._step_increment = self._steps_per_execution_value - 1 1167 self._insufficient_data = False 1168 1169 self._configure_dataset_and_inferred_steps(strategy, x, steps_per_epoch, 1170 class_weight, distribute) 1171 1172 def _configure_dataset_and_inferred_steps(self, strategy, x, steps_per_epoch, 1173 class_weight, distribute): 1174 """Configure the `_dataset` and `_inferred_steps` attributes.""" 1175 del x 1176 dataset = self._adapter.get_dataset() 1177 if class_weight: 1178 dataset = dataset.map(_make_class_weight_map_fn(class_weight)) 1179 self._inferred_steps = self._infer_steps(steps_per_epoch, dataset) 1180 1181 # `PreprocessingLayer.adapt` does not currently support distributed 1182 # datasets, so we pass `distribute=False` there. 1183 if distribute and not _is_distributed_dataset(dataset): 1184 dataset = strategy.experimental_distribute_dataset(dataset) 1185 self._dataset = dataset 1186 self._validate_data_handler() 1187 1188 def enumerate_epochs(self): 1189 """Yields `(epoch, tf.data.Iterator)`.""" 1190 with self._truncate_execution_to_epoch(): 1191 data_iterator = iter(self._dataset) 1192 for epoch in range(self._initial_epoch, self._epochs): 1193 if self._insufficient_data: # Set by `catch_stop_iteration`. 1194 break 1195 if self._adapter.should_recreate_iterator(): 1196 data_iterator = iter(self._dataset) 1197 yield epoch, data_iterator 1198 self._adapter.on_epoch_end() 1199 1200 @contextlib.contextmanager 1201 def _truncate_execution_to_epoch(self): 1202 """Truncates steps per execution to at most one epoch.""" 1203 should_truncate = ( 1204 self._inferred_steps is not None and 1205 self._steps_per_execution_value > self._inferred_steps) 1206 original_value = self._steps_per_execution_value 1207 try: 1208 if should_truncate: 1209 self._steps_per_execution.assign(self._inferred_steps) 1210 self._steps_per_execution_value = self._inferred_steps 1211 yield 1212 finally: 1213 if should_truncate: 1214 self._steps_per_execution.assign(original_value) 1215 self._steps_per_execution_value = original_value 1216 1217 def sync(self): 1218 context.async_wait() 1219 1220 @contextlib.contextmanager 1221 def catch_stop_iteration(self): 1222 """Catches errors when an iterator runs out of data.""" 1223 try: 1224 yield 1225 self.sync() 1226 except (StopIteration, errors.OutOfRangeError): 1227 if self._inferred_steps is None: 1228 self._inferred_steps = self._current_step 1229 else: 1230 self._insufficient_data = True 1231 total_epochs = self._epochs - self._initial_epoch 1232 logging.warning( 1233 "Your input ran out of data; interrupting training. " 1234 "Make sure that your dataset or generator can generate at " 1235 "least `steps_per_epoch * epochs` batches (in this case, " 1236 "{} batches). 
You may need to use the repeat() function " 1237 "when building your dataset.".format(total_epochs * 1238 self._inferred_steps)) 1239 1240 def steps(self): 1241 """Yields steps for the current epoch.""" 1242 self._current_step = 0 1243 # `self._inferred_steps` can be changed by `catch_stop_iteration`. 1244 while (self._inferred_steps is None or 1245 self._current_step < self._inferred_steps): 1246 if self._insufficient_data: # Set by `catch_stop_iteration`. 1247 break 1248 1249 can_run_full_execution = ( 1250 self._steps_per_execution_value == 1 or 1251 self._inferred_steps is None or 1252 self._inferred_steps - self._current_step >= 1253 self._steps_per_execution_value) 1254 1255 if can_run_full_execution: 1256 self._step_increment = self._steps_per_execution_value - 1 1257 yield self._current_step 1258 self._current_step += self._steps_per_execution_value 1259 else: 1260 # Last partial execution. 1261 steps_remaining = self._inferred_steps - self._current_step 1262 self._steps_per_execution.assign(steps_remaining) 1263 self._step_increment = steps_remaining - 1 1264 yield self._current_step 1265 self._current_step += steps_remaining 1266 self._steps_per_execution.assign(self._steps_per_execution_value) 1267 1268 @property 1269 def step_increment(self): 1270 """The number to increment the step for `on_batch_end` methods.""" 1271 return self._step_increment 1272 1273 @property 1274 def inferred_steps(self): 1275 """The inferred steps per epoch of the created `Dataset`. 1276 1277 This will be `None` in the case where: 1278 1279 (1) A `Dataset` of unknown cardinality was passed to the `DataHandler`, and 1280 (2) `steps_per_epoch` was not provided, and 1281 (3) The first epoch of iteration has not yet completed. 1282 1283 Returns: 1284 The inferred steps per epoch of the created `Dataset`. 1285 """ 1286 return self._inferred_steps 1287 1288 @property 1289 def should_sync(self): 1290 # Catch OutOfRangeError for Datasets of unknown size. 1291 # This blocks until the batch has finished executing. 1292 # TODO(b/150292341): Allow multiple async steps here. 1293 return self._inferred_steps is None 1294 1295 def _log_indefinite_training_warning(self): 1296 logging.warning("The training loop will run indefinitely since you have " 1297 "set `steps_per_epoch=-1`. Please use batch-level " 1298 "callbacks to save checkpoints or log training progress, " 1299 "etc") 1300 1301 def _infer_steps(self, steps, dataset): 1302 """Infers steps_per_epoch needed to loop through a dataset.""" 1303 if steps == -1: 1304 self._log_indefinite_training_warning() 1305 return None 1306 1307 if steps is not None: 1308 return steps 1309 1310 adapter_steps = self._adapter.get_size() 1311 if adapter_steps is not None: 1312 return adapter_steps 1313 1314 size = cardinality.cardinality(dataset) 1315 if size == cardinality.INFINITE and steps is None: 1316 raise ValueError( 1317 "When passing an infinitely repeating dataset, please specify a " 1318 "`steps_per_epoch` value so that epoch level " 1319 "callbacks continue to work. The value can be arbitrary, or a number " 1320 "that you think correctly defines the size of an epoch. " 1321 "Epoch-level callbacks will then be called at this interval.") 1322 if size >= 0: 1323 return size.numpy().item() 1324 return None 1325 1326 @property 1327 def _samples(self): 1328 return self._adapter.get_samples() 1329 1330 def _validate_data_handler(self): 1331 # TODO(b/152094471): Support this with DistIter.get_next_as_optional. 
1332 if self._steps_per_execution_value > 1 and self._inferred_steps is None: 1333 raise ValueError( 1334 "Could not infer the size of the data. With " 1335 "`steps_per_execution > 1`, you must specify the number of steps " 1336 "to run.") 1337 1338 1339class _ClusterCoordinatorDataHandler(DataHandler): 1340 """A `DataHandler` that is compatible with `ClusterCoordinator`.""" 1341 1342 def __init__(self, x, y=None, **kwargs): 1343 if not isinstance(x, dataset_creator.DatasetCreator): 1344 x = self._convert_to_dataset_creator(x, y, **kwargs) 1345 1346 super().__init__(x=x, **kwargs) 1347 1348 def _convert_to_dataset_creator(self, x, y, **kwargs): 1349 """Converts non-tf.data.Dataset to `DatasetCreator` instances.""" 1350 1351 def _dataset_fn(input_context): 1352 del input_context 1353 data_adapter_cls = select_data_adapter(x, y) 1354 return data_adapter_cls(x=x, y=y, **kwargs).get_dataset() 1355 1356 # This check is needed because types like `tf.data.Dataset` don't work with 1357 # PSS yet. So only apply this logic to the types we can support. 1358 if (isinstance(x, _get_tensor_types()) and 1359 isinstance(y, _get_tensor_types())): 1360 return dataset_creator.DatasetCreator(_dataset_fn) 1361 else: 1362 raise NotImplementedError( 1363 "Only `tf.keras.utils.experimental.DatasetCreator`, `tf.Tensor`, " 1364 "numpy arrays and pandas dataframes are supported types at this " 1365 "time.") 1366 1367 def _configure_dataset_and_inferred_steps(self, strategy, x, steps_per_epoch, 1368 class_weight, distribute): 1369 if not isinstance(x, dataset_creator.DatasetCreator): 1370 raise TypeError("When using `ParameterServerStrategy`, `x` must be a " 1371 "`DatasetCreator`.") 1372 1373 def per_worker_dataset_fn(): 1374 1375 return strategy.distribute_datasets_from_function( 1376 x, options=x.input_options) 1377 1378 self._dataset = self._model._cluster_coordinator.create_per_worker_dataset( # pylint: disable=protected-access 1379 per_worker_dataset_fn) 1380 1381 if steps_per_epoch == -1: 1382 self._inferred_steps = None 1383 self._log_indefinite_training_warning() 1384 else: 1385 self._inferred_steps = steps_per_epoch 1386 1387 def sync(self): 1388 self._model._cluster_coordinator.join() # pylint: disable=protected-access 1389 1390 1391def get_data_handler(*args, **kwargs): 1392 if getattr(kwargs["model"], "_cluster_coordinator", None): 1393 return _ClusterCoordinatorDataHandler(*args, **kwargs) 1394 return DataHandler(*args, **kwargs) 1395 1396 1397def _make_class_weight_map_fn(class_weight): 1398 """Applies class weighting to a `Dataset`. 1399 1400 The `Dataset` is assumed to be in format `(x, y)` or `(x, y, sw)`, where 1401 `y` must be a single `Tensor`. 1402 1403 Args: 1404 class_weight: A map where the keys are integer class ids and values are 1405 the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}` 1406 1407 Returns: 1408 A function that can be used with `tf.data.Dataset.map` to apply class 1409 weighting. 
1410 """ 1411 class_ids = list(sorted(class_weight.keys())) 1412 expected_class_ids = list(range(len(class_ids))) 1413 if class_ids != expected_class_ids: 1414 error_msg = ( 1415 "Expected `class_weight` to be a dict with keys from 0 to one less " 1416 "than the number of classes, found {}").format(class_weight) 1417 raise ValueError(error_msg) 1418 1419 class_weight_tensor = ops.convert_to_tensor_v2_with_dispatch( 1420 [class_weight[int(c)] for c in class_ids]) 1421 1422 def _class_weights_map_fn(*data): 1423 """Convert `class_weight` to `sample_weight`.""" 1424 x, y, sw = unpack_x_y_sample_weight(data) 1425 1426 if nest.is_nested(y): 1427 raise ValueError( 1428 "`class_weight` is only supported for Models with a single output.") 1429 1430 if y.shape.rank > 2: 1431 raise ValueError("`class_weight` not supported for " 1432 "3+ dimensional targets.") 1433 1434 y_classes = smart_cond.smart_cond( 1435 y.shape.rank == 2 and backend.shape(y)[1] > 1, 1436 lambda: backend.argmax(y, axis=1), 1437 lambda: math_ops.cast(backend.reshape(y, (-1,)), dtypes.int64)) 1438 1439 cw = array_ops.gather_v2(class_weight_tensor, y_classes) 1440 if sw is not None: 1441 cw = math_ops.cast(cw, sw.dtype) 1442 sw, cw = expand_1d((sw, cw)) 1443 # `class_weight` and `sample_weight` are multiplicative. 1444 sw = sw * cw 1445 else: 1446 sw = cw 1447 1448 return x, y, sw 1449 1450 return _class_weights_map_fn 1451 1452 1453def expand_1d(data): 1454 """Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s.""" 1455 1456 def _expand_single_1d_tensor(t): 1457 # Leaves `CompositeTensor`s as-is. 1458 if (isinstance(t, ops.Tensor) and 1459 isinstance(t.shape, tensor_shape.TensorShape) and t.shape.rank == 1): 1460 return array_ops.expand_dims_v2(t, axis=-1) 1461 return t 1462 1463 return nest.map_structure(_expand_single_1d_tensor, data) 1464 1465 1466def train_validation_split(arrays, validation_split): 1467 """Split arrays into train and validation subsets in deterministic order. 1468 1469 The last part of data will become validation data. 1470 1471 Args: 1472 arrays: Tensors to split. Allowed inputs are arbitrarily nested structures 1473 of Tensors and NumPy arrays. 1474 validation_split: Float between 0 and 1. The proportion of the dataset to 1475 include in the validation split. The rest of the dataset will be included 1476 in the training split. 1477 Returns: 1478 `(train_arrays, validation_arrays)` 1479 """ 1480 1481 def _can_split(t): 1482 tensor_types = _get_tensor_types() 1483 return isinstance(t, tensor_types) or t is None 1484 1485 flat_arrays = nest.flatten(arrays) 1486 unsplitable = [type(t) for t in flat_arrays if not _can_split(t)] 1487 if unsplitable: 1488 raise ValueError( 1489 "`validation_split` is only supported for Tensors or NumPy " 1490 "arrays, found following types in the input: {}".format(unsplitable)) 1491 1492 if all(t is None for t in flat_arrays): 1493 return arrays, arrays 1494 1495 first_non_none = None 1496 for t in flat_arrays: 1497 if t is not None: 1498 first_non_none = t 1499 break 1500 1501 # Assumes all arrays have the same batch shape or are `None`. 1502 batch_dim = int(first_non_none.shape[0]) 1503 split_at = int(math.floor(batch_dim * (1. - validation_split))) 1504 1505 if split_at == 0 or split_at == batch_dim: 1506 raise ValueError( 1507 "Training data contains {batch_dim} samples, which is not sufficient " 1508 "to split it into a validation and training set as specified by " 1509 "`validation_split={validation_split}`. 
Either provide more data, or a " 1510 "different value for the `validation_split` argument." .format( 1511 batch_dim=batch_dim, validation_split=validation_split)) 1512 1513 def _split(t, start, end): 1514 if t is None: 1515 return t 1516 return t[start:end] 1517 1518 train_arrays = nest.map_structure( 1519 functools.partial(_split, start=0, end=split_at), arrays) 1520 val_arrays = nest.map_structure( 1521 functools.partial(_split, start=split_at, end=batch_dim), arrays) 1522 1523 return train_arrays, val_arrays 1524 1525 1526@keras_export("keras.utils.unpack_x_y_sample_weight", v1=[]) 1527def unpack_x_y_sample_weight(data): 1528 """Unpacks user-provided data tuple. 1529 1530 This is a convenience utility to be used when overriding 1531 `Model.train_step`, `Model.test_step`, or `Model.predict_step`. 1532 This utility makes it easy to support data of the form `(x,)`, 1533 `(x, y)`, or `(x, y, sample_weight)`. 1534 1535 Standalone usage: 1536 1537 >>> features_batch = tf.ones((10, 5)) 1538 >>> labels_batch = tf.zeros((10, 5)) 1539 >>> data = (features_batch, labels_batch) 1540 >>> # `y` and `sample_weight` will default to `None` if not provided. 1541 >>> x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data) 1542 >>> sample_weight is None 1543 True 1544 1545 Example in overridden `Model.train_step`: 1546 1547 ```python 1548 class MyModel(tf.keras.Model): 1549 1550 def train_step(self, data): 1551 # If `sample_weight` is not provided, all samples will be weighted 1552 # equally. 1553 x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data) 1554 1555 with tf.GradientTape() as tape: 1556 y_pred = self(x, training=True) 1557 loss = self.compiled_loss( 1558 y, y_pred, sample_weight, regularization_losses=self.losses) 1559 trainable_variables = self.trainable_variables 1560 gradients = tape.gradient(loss, trainable_variables) 1561 self.optimizer.apply_gradients(zip(gradients, trainable_variables)) 1562 1563 self.compiled_metrics.update_state(y, y_pred, sample_weight) 1564 return {m.name: m.result() for m in self.metrics} 1565 ``` 1566 1567 Args: 1568 data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`. 1569 1570 Returns: 1571 The unpacked tuple, with `None`s for `y` and `sample_weight` if they are not 1572 provided. 1573 """ 1574 if not isinstance(data, tuple): 1575 return (data, None, None) 1576 elif len(data) == 1: 1577 return (data[0], None, None) 1578 elif len(data) == 2: 1579 return (data[0], data[1], None) 1580 elif len(data) == 3: 1581 return (data[0], data[1], data[2]) 1582 else: 1583 error_msg = ("Data is expected to be in format `x`, `(x,)`, `(x, y)`, " 1584 "or `(x, y, sample_weight)`, found: {}").format(data) 1585 raise ValueError(error_msg) 1586 1587 1588@keras_export("keras.utils.pack_x_y_sample_weight", v1=[]) 1589def pack_x_y_sample_weight(x, y=None, sample_weight=None): 1590 """Packs user-provided data into a tuple. 1591 1592 This is a convenience utility for packing data into the tuple formats 1593 that `Model.fit` uses. 1594 1595 Standalone usage: 1596 1597 >>> x = tf.ones((10, 1)) 1598 >>> data = tf.keras.utils.pack_x_y_sample_weight(x) 1599 >>> isinstance(data, tf.Tensor) 1600 True 1601 >>> y = tf.ones((10, 1)) 1602 >>> data = tf.keras.utils.pack_x_y_sample_weight(x, y) 1603 >>> isinstance(data, tuple) 1604 True 1605 >>> x, y = data 1606 1607 Args: 1608 x: Features to pass to `Model`. 1609 y: Ground-truth targets to pass to `Model`. 1610 sample_weight: Sample weight for each element. 
1611 1612 Returns: 1613 Tuple in the format used in `Model.fit`. 1614 """ 1615 if y is None: 1616 # For single x-input, we do no tuple wrapping since in this case 1617 # there is no ambiguity. This also makes NumPy and Dataset 1618 # consistent in that the user does not have to wrap their Dataset 1619 # data in an unecessary tuple 1620 if not nest.is_nested(x): 1621 return x 1622 else: 1623 return (x,) 1624 elif sample_weight is None: 1625 return (x, y) 1626 else: 1627 return (x, y, sample_weight) 1628 1629 1630def single_batch_iterator(strategy, 1631 x, 1632 y=None, 1633 sample_weight=None, 1634 class_weight=None): 1635 """Creates a single-batch dataset.""" 1636 x, y, sample_weight = _process_tensorlike((x, y, sample_weight)) 1637 if y is None: 1638 data = (x,) 1639 elif sample_weight is None: 1640 data = (x, y) 1641 else: 1642 data = (x, y, sample_weight) 1643 1644 _check_data_cardinality(data) 1645 dataset = dataset_ops.DatasetV2.from_tensors(data) 1646 if class_weight: 1647 dataset = dataset.map(_make_class_weight_map_fn(class_weight)) 1648 dataset = strategy.experimental_distribute_dataset(dataset) 1649 return iter(dataset) 1650 1651 1652def _check_data_cardinality(data): 1653 num_samples = set(int(i.shape[0]) for i in nest.flatten(data)) 1654 if len(num_samples) > 1: 1655 msg = "Data cardinality is ambiguous:\n" 1656 for label, single_data in zip(["x", "y", "sample_weight"], data): 1657 msg += " {} sizes: {}\n".format( 1658 label, ", ".join(str(i.shape[0]) for i in nest.flatten(single_data))) 1659 msg += "Make sure all arrays contain the same number of samples." 1660 raise ValueError(msg) 1661 1662 1663def _get_tensor_types(): 1664 try: 1665 import pandas as pd # pylint: disable=g-import-not-at-top 1666 1667 return (ops.Tensor, np.ndarray, pd.Series, pd.DataFrame) 1668 except ImportError: 1669 return (ops.Tensor, np.ndarray) 1670 1671 1672def _is_scipy_sparse(x): 1673 try: 1674 from scipy.sparse import issparse # pylint: disable=g-import-not-at-top 1675 1676 return issparse(x) 1677 except ImportError: 1678 return False 1679 1680 1681def _scipy_sparse_to_sparse_tensor(t): 1682 """Converts a SciPy sparse matrix to a SparseTensor.""" 1683 sparse_coo = t.tocoo() 1684 row, col = sparse_coo.row, sparse_coo.col 1685 data, shape = sparse_coo.data, sparse_coo.shape 1686 if issubclass(data.dtype.type, np.floating): 1687 data = data.astype(backend.floatx()) 1688 indices = np.concatenate( 1689 (np.expand_dims(row, axis=1), np.expand_dims(col, axis=1)), axis=1) 1690 return sparse_tensor.SparseTensor(indices, data, shape) 1691 1692 1693def _is_distributed_dataset(ds): 1694 return isinstance(ds, input_lib.DistributedDatasetInterface) 1695
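
# NOTE: Illustrative sketch only; not used by Keras itself. It demonstrates the
# COO-based conversion performed by `_scipy_sparse_to_sparse_tensor` above for
# a small, hypothetical SciPy matrix (requires `scipy` to be installed).
def _example_scipy_sparse_conversion():
  from scipy.sparse import csr_matrix  # pylint: disable=g-import-not-at-top
  matrix = csr_matrix(np.array([[0., 2.], [3., 0.]]))
  st = _scipy_sparse_to_sparse_tensor(matrix)
  # `st` is a `SparseTensor` with indices [[0, 1], [1, 0]], values [2., 3.]
  # cast to `backend.floatx()`, and dense_shape [2, 2].
  return st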