xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/engine/data_adapter.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Adapter module that convert different input data objects into tf.dataset."""
16
17import abc
18import contextlib
19import functools
20import itertools
21import math
22import random
23
24import numpy as np
25
26from tensorflow.python.data.experimental.ops import cardinality
27from tensorflow.python.data.ops import dataset_ops
28from tensorflow.python.data.ops import iterator_ops
29from tensorflow.python.data.ops import options as options_lib
30from tensorflow.python.distribute import distribution_strategy_context as ds_context
31from tensorflow.python.distribute import input_lib
32from tensorflow.python.eager import context
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import errors
35from tensorflow.python.framework import ops
36from tensorflow.python.framework import smart_cond
37from tensorflow.python.framework import sparse_tensor
38from tensorflow.python.framework import tensor_shape
39from tensorflow.python.keras import backend
40from tensorflow.python.keras.engine import training_utils
41from tensorflow.python.keras.utils import data_utils
42from tensorflow.python.keras.utils import dataset_creator
43from tensorflow.python.keras.utils import tf_utils
44from tensorflow.python.ops import array_ops
45from tensorflow.python.ops import math_ops
46from tensorflow.python.ops import random_ops
47from tensorflow.python.ops import script_ops
48from tensorflow.python.platform import tf_logging as logging
49from tensorflow.python.util import nest
50from tensorflow.python.util.tf_export import keras_export
51
52
53class DataAdapter(object, metaclass=abc.ABCMeta):
54  """Base class for input data adapter.
55
56  In TF 2.0, tf.data is the preferred API for user to feed in data. In order
57  to simplify the training code path, all the input data object will be
58  converted to `tf.data.Dataset` if possible.
59
60  Note that since this class is mainly targeted for TF 2.0, it might have a lot
61  of assumptions under the hood, eg eager context by default, distribution
62  strategy, etc. In the meantime, some legacy feature support might be dropped,
63  eg, Iterator from dataset API in v1, etc.
64
65  The sample usage of this class is like:
66
67  ```
68  x = tf.data.Dataset.range(100)
69  adapter_cls = [NumpyArrayDataAdapter, ..., DatasetAdapter]
70  applicable_adapters = [cls for cls in adapter_cls if cls.can_handle(x)]
71  if len(applicable_adapters) != 1:
72    raise ValueError("Expect only one adapter class to handle the input")
73
74  dataset = applicable_adapters[0](x).get_dataset()
75  for data in dataset:
76    # training
77  ```
78  """
79
80  @staticmethod
81  def can_handle(x, y=None):
82    """Whether the current DataAdapter could handle the input x and y.
83
84    Structure wise, x and y can be single object, or list of objects if there
85    multiple input/output, or dictionary of objects when the intput/output are
86    named.
87
88    Args:
89      x: input features.
90      y: target labels. Note that y could be None in the case of prediction.
91
92    Returns:
93      boolean
94    """
95    raise NotImplementedError
96
97  @abc.abstractmethod
98  def __init__(self, x, y=None, **kwargs):
99    """Create a DataAdapter based on data inputs.
100
101    The caller must make sure to call `can_handle()` first before invoking this
102    method. Provide unsupported data type will result into unexpected behavior.
103
104    Args:
105      x: input features.
106      y: target labels. Note that y could be None in the case of prediction.
107      **kwargs: Other keyword arguments for DataAdapter during the construction
108        of the tf.dataset.Dataset. For example:
109        - Numpy data might have `sample_weights` which will be used for
110          weighting the loss function during training.
111        - Numpy data might need to have `batch_size` parameter when constructing
112          the dataset and iterator.
113        - Certain input might need to be distribution strategy aware. When
114          `distribution_strategy` is passed, the created dataset need to respect
115          the strategy.
116        DataAdapter might choose to ignore any keyword argument if it doesn't
117        use it, or raise exception if any required argument is not provide.
118    """
119    if not self.can_handle(x, y):
120      raise ValueError("{} Cannot handle input {}, {}".format(
121          self.__class__, x, y))
122
123  @abc.abstractmethod
124  def get_dataset(self):
125    """Get a dataset instance for the current DataAdapter.
126
127    Note that the dataset returned does not repeat for epoch, so caller might
128    need to create new iterator for the same dataset at the beginning of the
129    epoch. This behavior might change in future.
130
131    Returns:
132      An tf.dataset.Dataset. Caller might use the dataset in different
133      context, eg iter(dataset) in eager to get the value directly, or in graph
134      mode, provide the iterator tensor to Keras model function.
135    """
136    raise NotImplementedError
137
138  @abc.abstractmethod
139  def get_size(self):
140    """Return the size (number of batches) for the dataset created.
141
142    For certain type of the data input, the number of batches is known, eg for
143    Numpy data, the size is same as (number_of_element / batch_size). Whereas
144    for dataset or python generator, the size is unknown since it may or may not
145    have a end state.
146
147    Returns:
148      int, the number of batches for the dataset, or None if it is unknown. The
149      caller could use this to control the loop of training, show progress bar,
150      or handle unexpected StopIteration error.
151    """
152    raise NotImplementedError
153
154  @abc.abstractmethod
155  def batch_size(self):
156    """Return the batch size of the dataset created.
157
158    For certain type of the data input, the batch size is known, and even
159    required, like numpy array. Where as for dataset, the batch is unknown
160    unless we take a peek.
161
162    Returns:
163      int, the batch size of the dataset, or None if it is unknown.
164    """
165    raise NotImplementedError
166
167  def representative_batch_size(self):
168    """Return a representative size for batches in the dataset.
169
170    This is not guaranteed to be the batch size for all batches in the
171    dataset. It just needs to be a rough approximation for batch sizes in
172    the dataset.
173
174    Returns:
175      int, a representative size for batches found in the dataset,
176      or None if it is unknown.
177    """
178    return self.batch_size()
179
180  @abc.abstractmethod
181  def has_partial_batch(self):
182    """Whether the dataset has partial batch at the end."""
183    raise NotImplementedError
184
185  @abc.abstractmethod
186  def partial_batch_size(self):
187    """The size of the final partial batch for dataset.
188
189    Will return None if has_partial_batch is False or batch_size is None.
190    """
191    raise NotImplementedError
192
193  @abc.abstractmethod
194  def should_recreate_iterator(self):
195    """Returns whether a new iterator should be created every epoch."""
196    raise NotImplementedError
197
198  def get_samples(self):
199    """Returns number of samples in the data, or `None`."""
200    if not self.get_size() or not self.batch_size():
201      return None
202    total_sample = self.get_size() * self.batch_size()
203    if self.has_partial_batch():
204      total_sample -= (self.batch_size() - self.partial_batch_size())
205    return total_sample
206
207  def on_epoch_end(self):
208    """A hook called after each epoch."""
209    pass
210
211
212class TensorLikeDataAdapter(DataAdapter):
213  """Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy."""
214
215  @staticmethod
216  def can_handle(x, y=None):
217    # TODO(kaftan): Check performance implications of using a flatten
218    #  here for other types of inputs.
219    flat_inputs = nest.flatten(x)
220    if y is not None:
221      flat_inputs += nest.flatten(y)
222
223    tensor_types = _get_tensor_types()
224
225    def _is_tensor(v):
226      if isinstance(v, tensor_types):
227        return True
228      return False
229
230    return all(_is_tensor(v) for v in flat_inputs)
231
232  def __init__(self,
233               x,
234               y=None,
235               sample_weights=None,
236               sample_weight_modes=None,
237               batch_size=None,
238               epochs=1,
239               steps=None,
240               shuffle=False,
241               **kwargs):
242    super(TensorLikeDataAdapter, self).__init__(x, y, **kwargs)
243    x, y, sample_weights = _process_tensorlike((x, y, sample_weights))
244    sample_weight_modes = broadcast_sample_weight_modes(
245        sample_weights, sample_weight_modes)
246
247    # If sample_weights are not specified for an output use 1.0 as weights.
248    (sample_weights, _, _) = training_utils.handle_partial_sample_weights(
249        y, sample_weights, sample_weight_modes, check_all_flat=True)
250
251    inputs = pack_x_y_sample_weight(x, y, sample_weights)
252
253    num_samples = set(int(i.shape[0]) for i in nest.flatten(inputs)).pop()
254    _check_data_cardinality(inputs)
255
256    # If batch_size is not passed but steps is, calculate from the input data.
257    # Default to 32 for backwards compat.
258    if not batch_size:
259      batch_size = int(math.ceil(num_samples / steps)) if steps else 32
260
261    self._size = int(math.ceil(num_samples / batch_size))
262    self._batch_size = batch_size
263
264    num_full_batches = int(num_samples // batch_size)
265    self._partial_batch_size = num_samples % batch_size
266
267    if isinstance(shuffle, str):
268      shuffle = shuffle.lower()
269
270    self._shuffle = shuffle
271    # Vectorized version of shuffle.
272    # This is a performance improvement over using `from_tensor_slices`.
273    # The indices of the data are shuffled and batched, and these indices
274    # are then zipped with the data and used to extract a batch of the data
275    # at each step. The performance improvements here come from:
276    # 1. vectorized batch using gather
277    # 2. parallelized map
278    # 3. pipelined permutation generation
279    # 4. optimized permutation batching
280    # 5. disabled static optimizations
281
282    indices_dataset = dataset_ops.DatasetV2.range(1)
283    if shuffle != "batch":
284      indices_dataset = indices_dataset.repeat(epochs)
285
286    def permutation(_):
287      # It turns out to be more performant to make a new set of indices rather
288      # than reusing the same range Tensor. (presumably because of buffer
289      # forwarding.)
290      indices = math_ops.range(num_samples, dtype=dtypes.int64)
291      if shuffle and shuffle != "batch":
292        indices = random_ops.random_shuffle(indices)
293      return indices
294
295    # We prefetch a single element. Computing large permutations can take quite
296    # a while so we don't want to wait for prefetching over an epoch boundary to
297    # trigger the next permutation. On the other hand, too many simultaneous
298    # shuffles can contend on a hardware level and degrade all performance.
299    indices_dataset = indices_dataset.map(permutation).prefetch(1)
300
301    def slice_batch_indices(indices):
302      """Convert a Tensor of indices into a dataset of batched indices.
303
304      This step can be accomplished in several ways. The most natural is to
305      slice the Tensor in a Dataset map. (With a condition on the upper index to
306      handle the partial batch.) However it turns out that coercing the Tensor
307      into a shape which is divisible by the batch size (and handling the last
308      partial batch separately) allows for a much more favorable memory access
309      pattern and improved performance.
310
311      Args:
312        indices: Tensor which determines the data order for an entire epoch.
313
314      Returns:
315        A Dataset of batched indices.
316      """
317      num_in_full_batch = num_full_batches * batch_size
318      first_k_indices = array_ops.slice(indices, [0], [num_in_full_batch])
319      first_k_indices = array_ops.reshape(
320          first_k_indices, [num_full_batches, batch_size])
321
322      flat_dataset = dataset_ops.DatasetV2.from_tensor_slices(first_k_indices)
323      if self._partial_batch_size:
324        index_remainder = dataset_ops.DatasetV2.from_tensors(array_ops.slice(
325            indices, [num_in_full_batch], [self._partial_batch_size]))
326        flat_dataset = flat_dataset.concatenate(index_remainder)
327
328      if shuffle == "batch":
329        # 1024 is a magic constant that has not been properly evaluated
330        flat_dataset = flat_dataset.shuffle(1024).repeat(epochs)
331      return flat_dataset
332
333    indices_dataset = indices_dataset.flat_map(slice_batch_indices)
334
335    dataset = self.slice_inputs(indices_dataset, inputs)
336
337    if shuffle == "batch":
338      def shuffle_batch(*batch):
339        return nest.map_structure(random_ops.random_shuffle, batch)
340      dataset = dataset.map(shuffle_batch)
341
342    self._dataset = dataset
343
344  def slice_inputs(self, indices_dataset, inputs):
345    """Slice inputs into a Dataset of batches.
346
347    Given a Dataset of batch indices and the unsliced inputs,
348    this step slices the inputs in a parallelized fashion
349    and produces a dataset of input batches.
350
351    Args:
352      indices_dataset: A Dataset of batched indices
353      inputs: A python data structure that contains the inputs, targets,
354        and possibly sample weights.
355
356    Returns:
357      A Dataset of input batches matching the batch indices.
358    """
359    dataset = dataset_ops.DatasetV2.zip((
360        indices_dataset,
361        dataset_ops.DatasetV2.from_tensors(inputs).repeat()
362    ))
363
364    def grab_batch(i, data):
365      return nest.map_structure(lambda d: array_ops.gather(d, i, axis=0), data)
366
367    dataset = dataset.map(
368        grab_batch, num_parallel_calls=dataset_ops.AUTOTUNE)
369
370    # Default optimizations are disabled to avoid the overhead of (unnecessary)
371    # input pipeline graph serialization and deserialization
372    options = options_lib.Options()
373    options.experimental_optimization.apply_default_optimizations = False
374    if self._shuffle:
375      # See b/141490660 for more details.
376      options.experimental_external_state_policy = (
377          options_lib.ExternalStatePolicy.IGNORE)
378    dataset = dataset.with_options(options)
379    return dataset
380
381  def get_dataset(self):
382    return self._dataset
383
384  def get_size(self):
385    return self._size
386
387  def batch_size(self):
388    return self._batch_size
389
390  def has_partial_batch(self):
391    return self._partial_batch_size > 0
392
393  def partial_batch_size(self):
394    return self._partial_batch_size or None
395
396  def should_recreate_iterator(self):
397    # An infinite dataset is always created here.
398    return False
399
400
401class GenericArrayLikeDataAdapter(TensorLikeDataAdapter):
402  """Adapter that handles array-like data without forcing it into memory.
403
404  This adapter handles array-like datasets that may be too big to fully
405  fit into memory.
406
407  Specifically, this adapter handles any Python class which implements:
408  `__get_item__`, `__len__`, `shape`, and `dtype` with the same meanings
409  as Numpy, but it ignores any case where all the inputs are Tensors or Numpy
410  arrays (because that case is handled by the base TensorLikeDataAdapter).
411
412  It ignores scipy sparse matrices and Composite Tensors because those are
413  handled by the CompositeTensorDataAdapter.
414
415  It also does not handle lists/tuples of scalars, because those are handled
416  by the ListsOfScalarsDataAdapter.
417  """
418
419  @staticmethod
420  def can_handle(x, y=None):
421    flat_inputs = nest.flatten(x)
422    if y is not None:
423      flat_inputs += nest.flatten(y)
424
425    def _is_array_like(v):
426      """Return True if v is a Tensor, array, or is array-like."""
427      return (
428          hasattr(v, "__getitem__") and
429          hasattr(v, "shape") and
430          hasattr(v, "dtype") and
431          hasattr(v, "__len__")
432      )
433
434    if (not TensorLikeDataAdapter.can_handle(x, y) and
435        not CompositeTensorDataAdapter.can_handle(x, y)):
436      return all(_is_array_like(v) for v in flat_inputs)
437    else:
438      return False
439
440  def __init__(self, *args, **kwargs):
441    logging.warning(
442        "Keras is training/fitting/evaluating on array-like data. Keras may "
443        "not be optimized for this format, so if your input data format is "
444        "supported by TensorFlow I/O (https://github.com/tensorflow/io) we "
445        "recommend using that to load a Dataset instead.")
446
447    super(GenericArrayLikeDataAdapter, self).__init__(*args, **kwargs)
448
449  def slice_inputs(self, indices_dataset, inputs):
450    """Slice inputs into a Dataset of batches.
451
452    Given a Dataset of batch indices and the unsliced inputs,
453    this step slices the inputs in a parallelized fashion
454    and produces a dataset of input batches.
455
456    Args:
457      indices_dataset: A Dataset of batched indices
458      inputs: A python data structure that contains the inputs, targets,
459        and possibly sample weights.
460
461    Returns:
462      A Dataset of input batches matching the batch indices.
463    """
464    flat_inputs = nest.flatten(inputs)
465    def dynamic_shape_like(t):
466      shape = list(t.shape)
467      shape[0] = None
468      return tuple(shape)
469
470    flat_dtypes = [inp.dtype for inp in flat_inputs]
471    contiguous = True
472    if self._shuffle and self._shuffle != "batch":
473      contiguous = False
474
475    def grab_batch(indices):
476      """Grab a batch of data from the inputs."""
477      # This uses a py_function to avoid converting the array-like
478      # into a Tensor before slicing it, because converting the array-like
479      # to a Tensor may force it into memory..
480      def py_method(ind):
481        def slice_array(data):
482          return training_utils.slice_arrays(data, ind.numpy(),
483                                             contiguous=contiguous)
484        return [slice_array(inp) for inp in flat_inputs]
485
486      flat_out = script_ops.eager_py_func(py_method, [indices], flat_dtypes)
487      for v, original_inp in zip(flat_out, flat_inputs):
488        v.set_shape(dynamic_shape_like(original_inp))
489      return nest.pack_sequence_as(inputs, flat_out)
490
491    dataset = indices_dataset.map(
492        grab_batch, num_parallel_calls=dataset_ops.AUTOTUNE)
493
494    return dataset
495
496
497class DatasetCreatorAdapter(DataAdapter):
498  """Adapter that handles dataset functions."""
499
500  def __init__(self, x, y, steps=None, distribution_strategy=None, **kwargs):
501    super(DatasetCreatorAdapter, self).__init__(x, **kwargs)
502
503    if not isinstance(x, dataset_creator.DatasetCreator):
504      raise TypeError("The input of a `DatasetCreatorAdapter` should be a "
505                      "`DatasetCreator` but it received type {}.".format(
506                          type(x)))
507    if steps is None:
508      raise ValueError("When using a "
509                       "`tf.keras.utils.experimental.DatasetCreator`, "
510                       "`steps_per_epoch`, `validation_steps` or `steps` "
511                       "argument must be provided in `Model.fit`, "
512                       "`Model.evaluate`, or `Model.predict`.")
513    self.dataset_creator = x
514    self.steps = steps
515    self.strategy = distribution_strategy
516
517  @staticmethod
518  def can_handle(x, y=None):
519    if isinstance(x, dataset_creator.DatasetCreator):
520      assert y is None
521      return True
522
523  def should_recreate_iterator(self):
524    # We expect users to shuffle the dataset in their `dataset_fn` supplied to
525    # `DatasetCreator`. Since that is a buffered shuffle, we intend to not reset
526    # the dataset so the batches that are not shuffled can still be pulled.
527    return False
528
529  def get_size(self):
530    return None  # To be inferred by `DataHandler`.
531
532  def get_dataset(self):
533    return self.strategy.distribute_datasets_from_function(
534        self.dataset_creator, options=self.dataset_creator.input_options)
535
536  def batch_size(self):
537    raise NotImplementedError()
538
539  def has_partial_batch(self):
540    raise NotImplementedError()
541
542  def partial_batch_size(self):
543    raise NotImplementedError()
544
545
546class CompositeTensorDataAdapter(DataAdapter):
547  """Adapter that handles composite tensor."""
548
549  @staticmethod
550  def can_handle(x, y=None):
551    flat_inputs = nest.flatten(x)
552    if y is not None:
553      flat_inputs += nest.flatten(y)
554
555    def _is_composite(v):
556      # Dataset/iterator/DistributedDataset inherits from CompositeTensor but
557      # should be handled by DatasetAdapter and GeneratorAdapter.
558      if (tf_utils.is_extension_type(v) and
559          not isinstance(v,
560                         (dataset_ops.DatasetV2, iterator_ops.IteratorBase)) and
561          not _is_distributed_dataset(v)):
562        return True
563      # Support Scipy sparse tensors if scipy is installed
564      return _is_scipy_sparse(v)
565
566    def _is_tensor_or_composite(v):
567      if isinstance(v, (ops.Tensor, np.ndarray)):
568        return True
569      return _is_composite(v)
570
571    return (any(_is_composite(v) for v in flat_inputs) and
572            all(_is_tensor_or_composite(v) for v in flat_inputs))
573
574  def __init__(self,
575               x,
576               y=None,
577               sample_weights=None,
578               sample_weight_modes=None,
579               batch_size=None,
580               steps=None,
581               shuffle=False,
582               **kwargs):
583    super(CompositeTensorDataAdapter, self).__init__(x, y, **kwargs)
584    x, y, sample_weights = _process_tensorlike((x, y, sample_weights))
585    sample_weight_modes = broadcast_sample_weight_modes(
586        sample_weights, sample_weight_modes)
587
588    # If sample_weights are not specified for an output use 1.0 as weights.
589    (sample_weights, _, _) = training_utils.handle_partial_sample_weights(
590        y, sample_weights, sample_weight_modes, check_all_flat=True)
591
592    inputs = pack_x_y_sample_weight(x, y, sample_weights)
593
594    dataset = dataset_ops.DatasetV2.from_tensor_slices(inputs)
595    num_samples = int(nest.flatten(x)[0].shape[0])
596    if shuffle:
597      dataset = dataset.shuffle(num_samples)
598
599    # If batch_size is not passed but steps is, calculate from the input data.
600    # Default to 32 for backwards compat.
601    if not batch_size:
602      batch_size = int(math.ceil(num_samples / steps)) if steps else 32
603
604    dataset = dataset.batch(batch_size)
605    self._size = int(math.ceil(num_samples / batch_size))
606    self._batch_size = batch_size
607    self._has_partial_batch = (self._size != (num_samples // batch_size))
608
609    self._partial_batch_size = None
610    if self._has_partial_batch:
611      self._partial_batch_size = (
612          num_samples - (self._size - 1) * self._batch_size)
613
614    self._dataset = dataset
615
616  def get_dataset(self):
617    return self._dataset
618
619  def get_size(self):
620    return self._size
621
622  def batch_size(self):
623    return self._batch_size
624
625  def has_partial_batch(self):
626    return self._has_partial_batch
627
628  def partial_batch_size(self):
629    return self._partial_batch_size
630
631  def should_recreate_iterator(self):
632    return True
633
634
635class ListsOfScalarsDataAdapter(DataAdapter):
636  """Adapter that handles lists of scalars and lists of lists of scalars."""
637
638  @staticmethod
639  def can_handle(x, y=None):
640    handles_x = ListsOfScalarsDataAdapter._is_list_of_scalars(x)
641    handles_y = True
642    if y is not None:
643      handles_y = ListsOfScalarsDataAdapter._is_list_of_scalars(y)
644    return handles_x and handles_y
645
646  @staticmethod
647  def _is_list_of_scalars(inp):
648    if isinstance(inp, (float, int, str, bytes, bytearray)):
649      return True
650    if isinstance(inp, (list, tuple)) and inp:
651      return ListsOfScalarsDataAdapter._is_list_of_scalars(inp[0])
652    return False
653
654  def __init__(self,
655               x,
656               y=None,
657               sample_weights=None,
658               sample_weight_modes=None,
659               batch_size=None,
660               shuffle=False,
661               **kwargs):
662    super(ListsOfScalarsDataAdapter, self).__init__(x, y, **kwargs)
663    x = np.asarray(x)
664    if y is not None:
665      y = np.asarray(y)
666    if sample_weights is not None:
667      sample_weights = np.asarray(sample_weights)
668    sample_weight_modes = broadcast_sample_weight_modes(
669        sample_weights, sample_weight_modes)
670
671    self._internal_adapter = TensorLikeDataAdapter(
672        x,
673        y=y,
674        sample_weights=sample_weights,
675        sample_weight_modes=sample_weight_modes,
676        batch_size=batch_size,
677        shuffle=shuffle,
678        **kwargs)
679
680  def get_dataset(self):
681    return self._internal_adapter.get_dataset()
682
683  def get_size(self):
684    return self._internal_adapter.get_size()
685
686  def batch_size(self):
687    return self._internal_adapter.batch_size()
688
689  def has_partial_batch(self):
690    return self._internal_adapter.has_partial_batch()
691
692  def partial_batch_size(self):
693    return self._internal_adapter.partial_batch_size()
694
695  def should_recreate_iterator(self):
696    return True
697
698
699class DatasetAdapter(DataAdapter):
700  """Adapter that handles `tf.data.Dataset`."""
701
702  @staticmethod
703  def can_handle(x, y=None):
704    return (isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)) or
705            _is_distributed_dataset(x))
706
707  def __init__(self,
708               x,
709               y=None,
710               sample_weights=None,
711               steps=None,
712               **kwargs):
713    super(DatasetAdapter, self).__init__(x, y, **kwargs)
714    # Note that the dataset instance is immutable, its fine to reuse the user
715    # provided dataset.
716    self._dataset = x
717
718    # The user-provided steps.
719    self._user_steps = steps
720
721    self._validate_args(y, sample_weights, steps)
722
723  def get_dataset(self):
724    return self._dataset
725
726  def get_size(self):
727    return  # Inferred in `DataHandler`.
728
729  def batch_size(self):
730    return None
731
732  def has_partial_batch(self):
733    return False
734
735  def partial_batch_size(self):
736    return None
737
738  def should_recreate_iterator(self):
739    # Since DistributedDatasets have no cardinality, the user must provide
740    # all steps that need to be run, calling `.repeat()` as needed.
741    if _is_distributed_dataset(self._dataset):
742      return False
743
744    # If user doesn't supply `steps`, or if they supply `steps` that
745    # exactly equals the size of the `Dataset`, create a new iterator
746    # each epoch.
747    return (self._user_steps is None or
748            cardinality.cardinality(self._dataset).numpy() == self._user_steps)
749
750  def _validate_args(self, y, sample_weights, steps):
751    """Validates `__init__` arguments."""
752    # Arguments that shouldn't be passed.
753    if not is_none_or_empty(y):
754      raise ValueError("`y` argument is not supported when using "
755                       "dataset as input.")
756    if not is_none_or_empty(sample_weights):
757      raise ValueError("`sample_weight` argument is not supported when using "
758                       "dataset as input.")
759
760    if steps is None:
761      if _is_distributed_dataset(self._dataset):
762        raise ValueError("When providing a distributed dataset, you must "
763                         "specify the number of steps to run.")
764
765      size = cardinality.cardinality(self._dataset).numpy()
766      if size == cardinality.INFINITE and steps is None:
767        raise ValueError(
768            "When providing an infinite dataset, you must specify "
769            "the number of steps to run (if you did not intend to "
770            "create an infinite dataset, make sure to not call "
771            "`repeat()` on the dataset).")
772
773
774class GeneratorDataAdapter(DataAdapter):
775  """Adapter that handles python generators and iterators."""
776
777  @staticmethod
778  def can_handle(x, y=None):
779    return ((hasattr(x, "__next__") or hasattr(x, "next"))
780            and hasattr(x, "__iter__")
781            and not isinstance(x, data_utils.Sequence))
782
783  def __init__(self,
784               x,
785               y=None,
786               sample_weights=None,
787               workers=1,
788               use_multiprocessing=False,
789               max_queue_size=10,
790               model=None,
791               **kwargs):
792    # Generators should never shuffle as exhausting the generator in order to
793    # shuffle the batches is inefficient.
794    kwargs.pop("shuffle", None)
795
796    if not is_none_or_empty(y):
797      raise ValueError("`y` argument is not supported when using "
798                       "python generator as input.")
799    if not is_none_or_empty(sample_weights):
800      raise ValueError("`sample_weight` argument is not supported when using "
801                       "python generator as input.")
802
803    super(GeneratorDataAdapter, self).__init__(x, y, **kwargs)
804
805    # Since we have to know the dtype of the python generator when we build the
806    # dataset, we have to look at a batch to infer the structure.
807    peek, x = self._peek_and_restore(x)
808    peek = self._standardize_batch(peek)
809    peek = _process_tensorlike(peek)
810
811    # Need to build the Model on concrete input shapes.
812    if model is not None and not model.built:
813      concrete_x, _, _ = unpack_x_y_sample_weight(peek)
814      model.distribute_strategy.run(
815          lambda x: model(x, training=False), args=(concrete_x,))
816
817    self._first_batch_size = int(nest.flatten(peek)[0].shape[0])
818
819    def _get_dynamic_shape(t):
820      shape = t.shape
821      # Unknown number of dimensions, `as_list` cannot be called.
822      if shape.rank is None:
823        return shape
824      return tensor_shape.TensorShape([None for _ in shape.as_list()])
825
826    output_shapes = nest.map_structure(_get_dynamic_shape, peek)
827    output_types = nest.map_structure(lambda t: t.dtype, peek)
828
829    # Note that dataset API takes a callable that creates a generator object,
830    # rather than generator itself, which is why we define a function here.
831    generator_fn = self._handle_multiprocessing(x, workers, use_multiprocessing,
832                                                max_queue_size)
833
834    def wrapped_generator():
835      for data in generator_fn():
836        yield self._standardize_batch(data)
837
838    dataset = dataset_ops.DatasetV2.from_generator(
839        wrapped_generator, output_types, output_shapes=output_shapes)
840
841    if workers == 1 and not use_multiprocessing:
842      dataset = dataset.prefetch(1)
843
844    self._dataset = dataset
845
846  def _standardize_batch(self, data):
847    """Standardizes a batch output by a generator."""
848    # Removes `None`s.
849    x, y, sample_weight = unpack_x_y_sample_weight(data)
850    data = pack_x_y_sample_weight(x, y, sample_weight)
851
852    data = nest.list_to_tuple(data)
853
854    def _convert_dtype(t):
855      if (isinstance(t, np.ndarray) and issubclass(t.dtype.type, np.floating)):
856        return np.array(t, dtype=backend.floatx())
857      return t
858
859    data = nest.map_structure(_convert_dtype, data)
860    return data
861
862  @staticmethod
863  def _peek_and_restore(x):
864    peek = next(x)
865    return peek, itertools.chain([peek], x)
866
867  def _handle_multiprocessing(self, x, workers, use_multiprocessing,
868                              max_queue_size):
869    """Create a callable, possibly including an Enqueuer."""
870    if workers > 1 or (workers > 0 and use_multiprocessing):
871      def generator_fn():
872        enqueuer = data_utils.GeneratorEnqueuer(
873            x, use_multiprocessing=use_multiprocessing)
874        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
875        return enqueuer.get()
876    else:
877      generator_fn = lambda: x
878    return generator_fn
879
880  def get_dataset(self):
881    return self._dataset
882
883  def get_size(self):
884    return None
885
886  def batch_size(self):
887    return None
888
889  def representative_batch_size(self):
890    return self._first_batch_size
891
892  def has_partial_batch(self):
893    return False
894
895  def partial_batch_size(self):
896    return
897
898  def should_recreate_iterator(self):
899    return False
900
901
902class KerasSequenceAdapter(GeneratorDataAdapter):
903  """Adapter that handles `keras.utils.Sequence`."""
904
905  @staticmethod
906  def can_handle(x, y=None):
907    return isinstance(x, data_utils.Sequence)
908
909  def __init__(self,
910               x,
911               y=None,
912               sample_weights=None,
913               shuffle=False,
914               workers=1,
915               use_multiprocessing=False,
916               max_queue_size=10,
917               model=None,
918               **kwargs):
919    if not is_none_or_empty(y):
920      raise ValueError("`y` argument is not supported when using "
921                       "`keras.utils.Sequence` as input.")
922    if not is_none_or_empty(sample_weights):
923      raise ValueError("`sample_weight` argument is not supported when using "
924                       "`keras.utils.Sequence` as input.")
925
926    self._size = len(x)
927    self._shuffle_sequence = shuffle
928    self._keras_sequence = x
929    self._enqueuer = None
930    super(KerasSequenceAdapter, self).__init__(
931        x,
932        shuffle=False,  # Shuffle is handed in the _make_callable override.
933        workers=workers,
934        use_multiprocessing=use_multiprocessing,
935        max_queue_size=max_queue_size,
936        model=model,
937        **kwargs)
938
939  @staticmethod
940  def _peek_and_restore(x):
941    return x[0], x
942
943  def _handle_multiprocessing(self, x, workers, use_multiprocessing,
944                              max_queue_size):
945    if workers > 1 or (workers > 0 and use_multiprocessing):
946      def generator_fn():
947        self._enqueuer = data_utils.OrderedEnqueuer(
948            x, use_multiprocessing=use_multiprocessing,
949            shuffle=self._shuffle_sequence)
950        self._enqueuer.start(workers=workers, max_queue_size=max_queue_size)
951        return self._enqueuer.get()
952    else:
953      def generator_fn():
954        order = range(len(x))
955        if self._shuffle_sequence:
956          # Match the shuffle convention in OrderedEnqueuer.
957          order = list(order)
958          random.shuffle(order)
959
960        for i in order:
961          yield x[i]
962
963    return generator_fn
964
965  def get_size(self):
966    return self._size
967
968  def should_recreate_iterator(self):
969    return True
970
971  def on_epoch_end(self):
972    if self._enqueuer:
973      self._enqueuer.stop()
974    self._keras_sequence.on_epoch_end()
975
976
977ALL_ADAPTER_CLS = [
978    ListsOfScalarsDataAdapter, TensorLikeDataAdapter,
979    GenericArrayLikeDataAdapter, DatasetAdapter, GeneratorDataAdapter,
980    KerasSequenceAdapter, CompositeTensorDataAdapter, DatasetCreatorAdapter
981]
982
983
984def select_data_adapter(x, y):
985  """Selects a data adapter than can handle a given x and y."""
986  adapter_cls = [cls for cls in ALL_ADAPTER_CLS if cls.can_handle(x, y)]
987  if not adapter_cls:
988    # TODO(scottzhu): This should be a less implementation-specific error.
989    raise ValueError(
990        "Failed to find data adapter that can handle "
991        "input: {}, {}".format(
992            _type_name(x), _type_name(y)))
993  elif len(adapter_cls) > 1:
994    raise RuntimeError(
995        "Data adapters should be mutually exclusive for "
996        "handling inputs. Found multiple adapters {} to handle "
997        "input: {}, {}".format(
998            adapter_cls, _type_name(x), _type_name(y)))
999  return adapter_cls[0]
1000
1001
1002def _type_name(x):
1003  """Generates a description of the type of an object."""
1004  if isinstance(x, dict):
1005    key_types = set(_type_name(key) for key in x.keys())
1006    val_types = set(_type_name(key) for key in x.values())
1007    return "({} containing {} keys and {} values)".format(
1008        type(x), key_types, val_types)
1009  if isinstance(x, (list, tuple)):
1010    types = set(_type_name(val) for val in x)
1011    return "({} containing values of types {})".format(
1012        type(x), types)
1013  return str(type(x))
1014
1015
1016def _process_tensorlike(inputs):
1017  """Process tensor-like inputs.
1018
1019  This function:
1020
1021  (1) Converts `Numpy` arrays to `Tensor`s.
1022  (2) Converts `Scipy` sparse matrices to `SparseTensor`s.
1023  (2) Converts `list`s to `tuple`s (for `tf.data` support).
1024
1025  Args:
1026    inputs: Structure of `Tensor`s, `NumPy` arrays, or tensor-like.
1027
1028  Returns:
1029    Structure of `Tensor`s or tensor-like.
1030  """
1031
1032  def _convert_numpy_and_scipy(x):
1033    if isinstance(x, np.ndarray):
1034      dtype = None
1035      if issubclass(x.dtype.type, np.floating):
1036        dtype = backend.floatx()
1037      return ops.convert_to_tensor_v2_with_dispatch(x, dtype=dtype)
1038    elif _is_scipy_sparse(x):
1039      return _scipy_sparse_to_sparse_tensor(x)
1040    return x
1041
1042  inputs = nest.map_structure(_convert_numpy_and_scipy, inputs)
1043  return nest.list_to_tuple(inputs)
1044
1045
1046def is_none_or_empty(inputs):
1047  # util method to check if the input is a None or a empty list.
1048  # the python "not" check will raise an error like below if the input is a
1049  # numpy array
1050  # "The truth value of an array with more than one element is ambiguous.
1051  # Use a.any() or a.all()"
1052  return inputs is None or not nest.flatten(inputs)
1053
1054
1055def broadcast_sample_weight_modes(target_structure, sample_weight_modes):
1056  """Match sample_weight_modes structure with output structure."""
1057  if target_structure is None or not nest.flatten(target_structure):
1058    return sample_weight_modes
1059
1060  if isinstance(sample_weight_modes, str):
1061    if isinstance(target_structure, dict):
1062      return {key: sample_weight_modes for key in target_structure.keys()}
1063    return [sample_weight_modes for _ in target_structure]
1064
1065  if sample_weight_modes:
1066    try:
1067      nest.assert_same_structure(
1068          training_utils.list_to_tuple(target_structure),
1069          training_utils.list_to_tuple(sample_weight_modes))
1070    except (ValueError, TypeError):
1071      target_str = str(nest.map_structure(lambda _: "...", target_structure))
1072      mode_str = str(nest.map_structure(lambda _: "...", sample_weight_modes))
1073
1074      # Attempt to coerce sample_weight_modes to the target structure. This
1075      # implicitly depends on the fact that Model flattens outputs for its
1076      # internal representation.
1077      try:
1078        sample_weight_modes = nest.pack_sequence_as(
1079            target_structure, nest.flatten(sample_weight_modes))
1080        logging.warning(
1081            "sample_weight modes were coerced from\n  {}\n    to  \n  {}"
1082            .format(target_str, mode_str))
1083      except (ValueError, TypeError):
1084        raise ValueError(
1085            "Unable to match target structure and sample_weight_modes "
1086            "structure:\n  {}\n    to  \n  {}".format(target_str, mode_str))
1087
1088  return sample_weight_modes
1089
1090
1091class DataHandler(object):
1092  """Handles iterating over epoch-level `tf.data.Iterator` objects."""
1093
1094  def __init__(self,
1095               x,
1096               y=None,
1097               sample_weight=None,
1098               batch_size=None,
1099               steps_per_epoch=None,
1100               initial_epoch=0,
1101               epochs=1,
1102               shuffle=False,
1103               class_weight=None,
1104               max_queue_size=10,
1105               workers=1,
1106               use_multiprocessing=False,
1107               model=None,
1108               steps_per_execution=None,
1109               distribute=True):
1110    """Initializes a `DataHandler`.
1111
1112    Arguments:
1113      x: See `Model.fit`.
1114      y: See `Model.fit`.
1115      sample_weight: See `Model.fit`.
1116      batch_size: See `Model.fit`.
1117      steps_per_epoch: See `Model.fit`.
1118      initial_epoch: See `Model.fit`.
1119      epochs: See `Model.fit`.
1120      shuffle: See `Model.fit`.
1121      class_weight: See `Model.fit`.
1122      max_queue_size: See `Model.fit`.
1123      workers: See `Model.fit`.
1124      use_multiprocessing: See `Model.fit`.
1125      model: The `Model` instance. Needed in order to correctly `build` the
1126        `Model` using generator-like inputs (see `GeneratorDataAdapter`).
1127      steps_per_execution: See `Model.compile`.
1128      distribute: Whether to distribute the `tf.dataset`.
1129        `PreprocessingLayer.adapt` does not support distributed datasets,
1130        `Model` should always set this to `True`.
1131    """
1132
1133    self._initial_epoch = initial_epoch
1134    self._epochs = epochs
1135    self._insufficient_data = False
1136    self._model = model
1137
1138    # `steps_per_execution_value` is the cached initial value.
1139    # `steps_per_execution` is mutable and may be changed by the DataAdapter
1140    # to handle partial executions.
1141    if steps_per_execution is None:
1142      self._steps_per_execution = 1
1143      self._steps_per_execution_value = 1
1144    else:
1145      self._steps_per_execution = steps_per_execution
1146      self._steps_per_execution_value = steps_per_execution.numpy().item()
1147
1148    adapter_cls = select_data_adapter(x, y)
1149    self._adapter = adapter_cls(
1150        x,
1151        y,
1152        batch_size=batch_size,
1153        steps=steps_per_epoch,
1154        epochs=epochs - initial_epoch,
1155        sample_weights=sample_weight,
1156        shuffle=shuffle,
1157        max_queue_size=max_queue_size,
1158        workers=workers,
1159        use_multiprocessing=use_multiprocessing,
1160        distribution_strategy=ds_context.get_strategy(),
1161        model=model)
1162
1163    strategy = ds_context.get_strategy()
1164
1165    self._current_step = 0
1166    self._step_increment = self._steps_per_execution_value - 1
1167    self._insufficient_data = False
1168
1169    self._configure_dataset_and_inferred_steps(strategy, x, steps_per_epoch,
1170                                               class_weight, distribute)
1171
1172  def _configure_dataset_and_inferred_steps(self, strategy, x, steps_per_epoch,
1173                                            class_weight, distribute):
1174    """Configure the `_dataset` and `_inferred_steps` attributes."""
1175    del x
1176    dataset = self._adapter.get_dataset()
1177    if class_weight:
1178      dataset = dataset.map(_make_class_weight_map_fn(class_weight))
1179    self._inferred_steps = self._infer_steps(steps_per_epoch, dataset)
1180
1181    # `PreprocessingLayer.adapt` does not currently support distributed
1182    # datasets, so we pass `distribute=False` there.
1183    if distribute and not _is_distributed_dataset(dataset):
1184      dataset = strategy.experimental_distribute_dataset(dataset)
1185    self._dataset = dataset
1186    self._validate_data_handler()
1187
1188  def enumerate_epochs(self):
1189    """Yields `(epoch, tf.data.Iterator)`."""
1190    with self._truncate_execution_to_epoch():
1191      data_iterator = iter(self._dataset)
1192      for epoch in range(self._initial_epoch, self._epochs):
1193        if self._insufficient_data:  # Set by `catch_stop_iteration`.
1194          break
1195        if self._adapter.should_recreate_iterator():
1196          data_iterator = iter(self._dataset)
1197        yield epoch, data_iterator
1198        self._adapter.on_epoch_end()
1199
1200  @contextlib.contextmanager
1201  def _truncate_execution_to_epoch(self):
1202    """Truncates steps per execution to at most one epoch."""
1203    should_truncate = (
1204        self._inferred_steps is not None and
1205        self._steps_per_execution_value > self._inferred_steps)
1206    original_value = self._steps_per_execution_value
1207    try:
1208      if should_truncate:
1209        self._steps_per_execution.assign(self._inferred_steps)
1210        self._steps_per_execution_value = self._inferred_steps
1211      yield
1212    finally:
1213      if should_truncate:
1214        self._steps_per_execution.assign(original_value)
1215        self._steps_per_execution_value = original_value
1216
1217  def sync(self):
1218    context.async_wait()
1219
1220  @contextlib.contextmanager
1221  def catch_stop_iteration(self):
1222    """Catches errors when an iterator runs out of data."""
1223    try:
1224      yield
1225      self.sync()
1226    except (StopIteration, errors.OutOfRangeError):
1227      if self._inferred_steps is None:
1228        self._inferred_steps = self._current_step
1229      else:
1230        self._insufficient_data = True
1231        total_epochs = self._epochs - self._initial_epoch
1232        logging.warning(
1233            "Your input ran out of data; interrupting training. "
1234            "Make sure that your dataset or generator can generate at "
1235            "least `steps_per_epoch * epochs` batches (in this case, "
1236            "{} batches). You may need to use the repeat() function "
1237            "when building your dataset.".format(total_epochs *
1238                                                 self._inferred_steps))
1239
1240  def steps(self):
1241    """Yields steps for the current epoch."""
1242    self._current_step = 0
1243    # `self._inferred_steps` can be changed by `catch_stop_iteration`.
1244    while (self._inferred_steps is None or
1245           self._current_step < self._inferred_steps):
1246      if self._insufficient_data:  # Set by `catch_stop_iteration`.
1247        break
1248
1249      can_run_full_execution = (
1250          self._steps_per_execution_value == 1 or
1251          self._inferred_steps is None or
1252          self._inferred_steps - self._current_step >=
1253          self._steps_per_execution_value)
1254
1255      if can_run_full_execution:
1256        self._step_increment = self._steps_per_execution_value - 1
1257        yield self._current_step
1258        self._current_step += self._steps_per_execution_value
1259      else:
1260        # Last partial execution.
1261        steps_remaining = self._inferred_steps - self._current_step
1262        self._steps_per_execution.assign(steps_remaining)
1263        self._step_increment = steps_remaining - 1
1264        yield self._current_step
1265        self._current_step += steps_remaining
1266        self._steps_per_execution.assign(self._steps_per_execution_value)
1267
1268  @property
1269  def step_increment(self):
1270    """The number to increment the step for `on_batch_end` methods."""
1271    return self._step_increment
1272
1273  @property
1274  def inferred_steps(self):
1275    """The inferred steps per epoch of the created `Dataset`.
1276
1277    This will be `None` in the case where:
1278
1279    (1) A `Dataset` of unknown cardinality was passed to the `DataHandler`, and
1280    (2) `steps_per_epoch` was not provided, and
1281    (3) The first epoch of iteration has not yet completed.
1282
1283    Returns:
1284      The inferred steps per epoch of the created `Dataset`.
1285    """
1286    return self._inferred_steps
1287
1288  @property
1289  def should_sync(self):
1290    # Catch OutOfRangeError for Datasets of unknown size.
1291    # This blocks until the batch has finished executing.
1292    # TODO(b/150292341): Allow multiple async steps here.
1293    return self._inferred_steps is None
1294
1295  def _log_indefinite_training_warning(self):
1296    logging.warning("The training loop will run indefinitely since you have "
1297                    "set `steps_per_epoch=-1`. Please use batch-level "
1298                    "callbacks to save checkpoints or log training progress, "
1299                    "etc")
1300
1301  def _infer_steps(self, steps, dataset):
1302    """Infers steps_per_epoch needed to loop through a dataset."""
1303    if steps == -1:
1304      self._log_indefinite_training_warning()
1305      return None
1306
1307    if steps is not None:
1308      return steps
1309
1310    adapter_steps = self._adapter.get_size()
1311    if adapter_steps is not None:
1312      return adapter_steps
1313
1314    size = cardinality.cardinality(dataset)
1315    if size == cardinality.INFINITE and steps is None:
1316      raise ValueError(
1317          "When passing an infinitely repeating dataset, please specify a "
1318          "`steps_per_epoch` value so that epoch level "
1319          "callbacks continue to work. The value can be arbitrary, or a number "
1320          "that you think correctly defines the size of an epoch. "
1321          "Epoch-level callbacks will then be called at this interval.")
1322    if size >= 0:
1323      return size.numpy().item()
1324    return None
1325
1326  @property
1327  def _samples(self):
1328    return self._adapter.get_samples()
1329
1330  def _validate_data_handler(self):
1331    # TODO(b/152094471): Support this with DistIter.get_next_as_optional.
1332    if self._steps_per_execution_value > 1 and self._inferred_steps is None:
1333      raise ValueError(
1334          "Could not infer the size of the data. With "
1335          "`steps_per_execution > 1`, you must specify the number of steps "
1336          "to run.")
1337
1338
1339class _ClusterCoordinatorDataHandler(DataHandler):
1340  """A `DataHandler` that is compatible with `ClusterCoordinator`."""
1341
1342  def __init__(self, x, y=None, **kwargs):
1343    if not isinstance(x, dataset_creator.DatasetCreator):
1344      x = self._convert_to_dataset_creator(x, y, **kwargs)
1345
1346    super().__init__(x=x, **kwargs)
1347
1348  def _convert_to_dataset_creator(self, x, y, **kwargs):
1349    """Converts non-tf.data.Dataset to `DatasetCreator` instances."""
1350
1351    def _dataset_fn(input_context):
1352      del input_context
1353      data_adapter_cls = select_data_adapter(x, y)
1354      return data_adapter_cls(x=x, y=y, **kwargs).get_dataset()
1355
1356    # This check is needed because types like `tf.data.Dataset` don't work with
1357    # PSS yet. So only apply this logic to the types we can support.
1358    if (isinstance(x, _get_tensor_types()) and
1359        isinstance(y, _get_tensor_types())):
1360      return dataset_creator.DatasetCreator(_dataset_fn)
1361    else:
1362      raise NotImplementedError(
1363          "Only `tf.keras.utils.experimental.DatasetCreator`, `tf.Tensor`, "
1364          "numpy arrays and pandas dataframes are supported types at this "
1365          "time.")
1366
1367  def _configure_dataset_and_inferred_steps(self, strategy, x, steps_per_epoch,
1368                                            class_weight, distribute):
1369    if not isinstance(x, dataset_creator.DatasetCreator):
1370      raise TypeError("When using `ParameterServerStrategy`, `x` must be a "
1371                      "`DatasetCreator`.")
1372
1373    def per_worker_dataset_fn():
1374
1375      return strategy.distribute_datasets_from_function(
1376          x, options=x.input_options)
1377
1378    self._dataset = self._model._cluster_coordinator.create_per_worker_dataset(  # pylint: disable=protected-access
1379        per_worker_dataset_fn)
1380
1381    if steps_per_epoch == -1:
1382      self._inferred_steps = None
1383      self._log_indefinite_training_warning()
1384    else:
1385      self._inferred_steps = steps_per_epoch
1386
1387  def sync(self):
1388    self._model._cluster_coordinator.join()  # pylint: disable=protected-access
1389
1390
1391def get_data_handler(*args, **kwargs):
1392  if getattr(kwargs["model"], "_cluster_coordinator", None):
1393    return _ClusterCoordinatorDataHandler(*args, **kwargs)
1394  return DataHandler(*args, **kwargs)
1395
1396
1397def _make_class_weight_map_fn(class_weight):
1398  """Applies class weighting to a `Dataset`.
1399
1400  The `Dataset` is assumed to be in format `(x, y)` or `(x, y, sw)`, where
1401  `y` must be a single `Tensor`.
1402
1403  Args:
1404    class_weight: A map where the keys are integer class ids and values are
1405      the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}`
1406
1407  Returns:
1408    A function that can be used with `tf.data.Dataset.map` to apply class
1409    weighting.
1410  """
1411  class_ids = list(sorted(class_weight.keys()))
1412  expected_class_ids = list(range(len(class_ids)))
1413  if class_ids != expected_class_ids:
1414    error_msg = (
1415        "Expected `class_weight` to be a dict with keys from 0 to one less "
1416        "than the number of classes, found {}").format(class_weight)
1417    raise ValueError(error_msg)
1418
1419  class_weight_tensor = ops.convert_to_tensor_v2_with_dispatch(
1420      [class_weight[int(c)] for c in class_ids])
1421
1422  def _class_weights_map_fn(*data):
1423    """Convert `class_weight` to `sample_weight`."""
1424    x, y, sw = unpack_x_y_sample_weight(data)
1425
1426    if nest.is_nested(y):
1427      raise ValueError(
1428          "`class_weight` is only supported for Models with a single output.")
1429
1430    if y.shape.rank > 2:
1431      raise ValueError("`class_weight` not supported for "
1432                       "3+ dimensional targets.")
1433
1434    y_classes = smart_cond.smart_cond(
1435        y.shape.rank == 2 and backend.shape(y)[1] > 1,
1436        lambda: backend.argmax(y, axis=1),
1437        lambda: math_ops.cast(backend.reshape(y, (-1,)), dtypes.int64))
1438
1439    cw = array_ops.gather_v2(class_weight_tensor, y_classes)
1440    if sw is not None:
1441      cw = math_ops.cast(cw, sw.dtype)
1442      sw, cw = expand_1d((sw, cw))
1443      # `class_weight` and `sample_weight` are multiplicative.
1444      sw = sw * cw
1445    else:
1446      sw = cw
1447
1448    return x, y, sw
1449
1450  return _class_weights_map_fn
1451
1452
1453def expand_1d(data):
1454  """Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s."""
1455
1456  def _expand_single_1d_tensor(t):
1457    # Leaves `CompositeTensor`s as-is.
1458    if (isinstance(t, ops.Tensor) and
1459        isinstance(t.shape, tensor_shape.TensorShape) and t.shape.rank == 1):
1460      return array_ops.expand_dims_v2(t, axis=-1)
1461    return t
1462
1463  return nest.map_structure(_expand_single_1d_tensor, data)
1464
1465
1466def train_validation_split(arrays, validation_split):
1467  """Split arrays into train and validation subsets in deterministic order.
1468
1469  The last part of data will become validation data.
1470
1471  Args:
1472    arrays: Tensors to split. Allowed inputs are arbitrarily nested structures
1473      of Tensors and NumPy arrays.
1474    validation_split: Float between 0 and 1. The proportion of the dataset to
1475      include in the validation split. The rest of the dataset will be included
1476      in the training split.
1477  Returns:
1478    `(train_arrays, validation_arrays)`
1479  """
1480
1481  def _can_split(t):
1482    tensor_types = _get_tensor_types()
1483    return isinstance(t, tensor_types) or t is None
1484
1485  flat_arrays = nest.flatten(arrays)
1486  unsplittable = [type(t) for t in flat_arrays if not _can_split(t)]
1487  if unsplittable:
1488    raise ValueError(
1489        "`validation_split` is only supported for Tensors or NumPy arrays, "
1490        "found the following types in the input: {}".format(unsplittable))
1491
1492  if all(t is None for t in flat_arrays):
1493    return arrays, arrays
1494
1495  first_non_none = None
1496  for t in flat_arrays:
1497    if t is not None:
1498      first_non_none = t
1499      break
1500
1501  # Assumes all arrays have the same batch shape or are `None`.
1502  batch_dim = int(first_non_none.shape[0])
1503  split_at = int(math.floor(batch_dim * (1. - validation_split)))
1504
1505  if split_at == 0 or split_at == batch_dim:
1506    raise ValueError(
1507        "Training data contains {batch_dim} samples, which is not sufficient "
1508        "to split it into a validation and training set as specified by "
1509        "`validation_split={validation_split}`. Either provide more data, or a "
1510        "different value for the `validation_split` argument." .format(
1511            batch_dim=batch_dim, validation_split=validation_split))
1512
1513  def _split(t, start, end):
1514    if t is None:
1515      return t
1516    return t[start:end]
1517
1518  train_arrays = nest.map_structure(
1519      functools.partial(_split, start=0, end=split_at), arrays)
1520  val_arrays = nest.map_structure(
1521      functools.partial(_split, start=split_at, end=batch_dim), arrays)
1522
1523  return train_arrays, val_arrays
1524
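# A minimal sketch with hypothetical shapes (not part of the original module):
# with 10 samples and `validation_split=0.2`, `split_at` is 8, so the first 8
# samples form the training subset and the last 2 the validation subset.
#
#   x = np.arange(10).reshape((10, 1))
#   y = np.arange(10)
#   (x_train, y_train), (x_val, y_val) = train_validation_split(
#       (x, y), validation_split=0.2)
#   # x_train.shape == (8, 1), x_val.shape == (2, 1)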
1525
1526@keras_export("keras.utils.unpack_x_y_sample_weight", v1=[])
1527def unpack_x_y_sample_weight(data):
1528  """Unpacks user-provided data tuple.
1529
1530  This is a convenience utility to be used when overriding
1531  `Model.train_step`, `Model.test_step`, or `Model.predict_step`.
1532  This utility makes it easy to support data of the form `(x,)`,
1533  `(x, y)`, or `(x, y, sample_weight)`.
1534
1535  Standalone usage:
1536
1537  >>> features_batch = tf.ones((10, 5))
1538  >>> labels_batch = tf.zeros((10, 5))
1539  >>> data = (features_batch, labels_batch)
1540  >>> # `y` and `sample_weight` will default to `None` if not provided.
1541  >>> x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
1542  >>> sample_weight is None
1543  True
1544
1545  Example in overridden `Model.train_step`:
1546
1547  ```python
1548  class MyModel(tf.keras.Model):
1549
1550    def train_step(self, data):
1551      # If `sample_weight` is not provided, all samples will be weighted
1552      # equally.
1553      x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
1554
1555      with tf.GradientTape() as tape:
1556        y_pred = self(x, training=True)
1557        loss = self.compiled_loss(
1558          y, y_pred, sample_weight, regularization_losses=self.losses)
1559      trainable_variables = self.trainable_variables
1560      gradients = tape.gradient(loss, trainable_variables)
1561      self.optimizer.apply_gradients(zip(gradients, trainable_variables))
1562
1563      self.compiled_metrics.update_state(y, y_pred, sample_weight)
1564      return {m.name: m.result() for m in self.metrics}
1565  ```
1566
1567  Args:
1568    data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`.
1569
1570  Returns:
1571    The unpacked tuple, with `None`s for `y` and `sample_weight` if they are not
1572    provided.
1573  """
1574  if not isinstance(data, tuple):
1575    return (data, None, None)
1576  elif len(data) == 1:
1577    return (data[0], None, None)
1578  elif len(data) == 2:
1579    return (data[0], data[1], None)
1580  elif len(data) == 3:
1581    return (data[0], data[1], data[2])
1582  else:
1583    error_msg = ("Data is expected to be in format `x`, `(x,)`, `(x, y)`, "
1584                 "or `(x, y, sample_weight)`, found: {}").format(data)
1585    raise ValueError(error_msg)
1586
1587
1588@keras_export("keras.utils.pack_x_y_sample_weight", v1=[])
1589def pack_x_y_sample_weight(x, y=None, sample_weight=None):
1590  """Packs user-provided data into a tuple.
1591
1592  This is a convenience utility for packing data into the tuple formats
1593  that `Model.fit` uses.
1594
1595  Standalone usage:
1596
1597  >>> x = tf.ones((10, 1))
1598  >>> data = tf.keras.utils.pack_x_y_sample_weight(x)
1599  >>> isinstance(data, tf.Tensor)
1600  True
1601  >>> y = tf.ones((10, 1))
1602  >>> data = tf.keras.utils.pack_x_y_sample_weight(x, y)
1603  >>> isinstance(data, tuple)
1604  True
1605  >>> x, y = data
1606
1607  Args:
1608    x: Features to pass to `Model`.
1609    y: Ground-truth targets to pass to `Model`.
1610    sample_weight: Sample weight for each element.
1611
1612  Returns:
1613    Tuple in the format used in `Model.fit`.
1614  """
1615  if y is None:
1616    # For a single x-input, we do no tuple wrapping since in this case
1617    # there is no ambiguity. This also makes NumPy and Dataset
1618    # consistent in that the user does not have to wrap their Dataset
1619    # data in an unnecessary tuple.
1620    if not nest.is_nested(x):
1621      return x
1622    else:
1623      return (x,)
1624  elif sample_weight is None:
1625    return (x, y)
1626  else:
1627    return (x, y, sample_weight)
1628
1629
1630def single_batch_iterator(strategy,
1631                          x,
1632                          y=None,
1633                          sample_weight=None,
1634                          class_weight=None):
1635  """Creates a single-batch dataset."""
1636  x, y, sample_weight = _process_tensorlike((x, y, sample_weight))
1637  if y is None:
1638    data = (x,)
1639  elif sample_weight is None:
1640    data = (x, y)
1641  else:
1642    data = (x, y, sample_weight)
1643
1644  _check_data_cardinality(data)
1645  dataset = dataset_ops.DatasetV2.from_tensors(data)
1646  if class_weight:
1647    dataset = dataset.map(_make_class_weight_map_fn(class_weight))
1648  dataset = strategy.experimental_distribute_dataset(dataset)
1649  return iter(dataset)
1650
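# A hedged usage sketch (the strategy and tensors below are illustrative
# assumptions): one batch is packed into a dataset, optionally class-weighted,
# distributed across the strategy, and returned as an iterator yielding a
# single distributed batch.
#
#   strategy = ds_context.get_strategy()
#   it = single_batch_iterator(strategy, x=np.ones((4, 3)), y=np.zeros((4,)))
#   x_batch, y_batch = next(it)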
1651
1652def _check_data_cardinality(data):
1653  num_samples = set(int(i.shape[0]) for i in nest.flatten(data))
1654  if len(num_samples) > 1:
1655    msg = "Data cardinality is ambiguous:\n"
1656    for label, single_data in zip(["x", "y", "sample_weight"], data):
1657      msg += "  {} sizes: {}\n".format(
1658          label, ", ".join(str(i.shape[0]) for i in nest.flatten(single_data)))
1659    msg += "Make sure all arrays contain the same number of samples."
1660    raise ValueError(msg)
1661
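# Illustrative failure case (editorial example): mismatched leading dimensions
# raise a ValueError listing the sizes seen for `x`, `y` and `sample_weight`.
#
#   _check_data_cardinality((np.ones((10, 2)), np.ones((8,))))
#   # ValueError: Data cardinality is ambiguous:
#   #   x sizes: 10
#   #   y sizes: 8
#   # Make sure all arrays contain the same number of samples.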
1662
1663def _get_tensor_types():
1664  try:
1665    import pandas as pd  # pylint: disable=g-import-not-at-top
1666
1667    return (ops.Tensor, np.ndarray, pd.Series, pd.DataFrame)
1668  except ImportError:
1669    return (ops.Tensor, np.ndarray)
1670
1671
1672def _is_scipy_sparse(x):
1673  try:
1674    from scipy.sparse import issparse  # pylint: disable=g-import-not-at-top
1675
1676    return issparse(x)
1677  except ImportError:
1678    return False
1679
1680
1681def _scipy_sparse_to_sparse_tensor(t):
1682  """Converts a SciPy sparse matrix to a SparseTensor."""
1683  sparse_coo = t.tocoo()
1684  row, col = sparse_coo.row, sparse_coo.col
1685  data, shape = sparse_coo.data, sparse_coo.shape
1686  if issubclass(data.dtype.type, np.floating):
1687    data = data.astype(backend.floatx())
1688  indices = np.concatenate(
1689      (np.expand_dims(row, axis=1), np.expand_dims(col, axis=1)), axis=1)
1690  return sparse_tensor.SparseTensor(indices, data, shape)
1691
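# A brief sketch (assumes SciPy is installed; purely illustrative): a SciPy
# sparse matrix becomes an equivalent `SparseTensor` whose `indices` are the
# (row, col) pairs of its non-zero entries.
#
#   from scipy import sparse
#   m = sparse.csr_matrix(np.array([[0., 1.], [2., 0.]]))
#   st = _scipy_sparse_to_sparse_tensor(m)
#   # st.indices -> [[0, 1], [1, 0]]; st.values -> [1., 2.]; dense shape (2, 2)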
1692
1693def _is_distributed_dataset(ds):
1694  return isinstance(ds, input_lib.DistributedDatasetInterface)
1695