xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/engine/training_utils_v1.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training-related utilities."""
16
17import abc
18import atexit
19import collections
20import functools
21import multiprocessing.pool
22import threading
23import time
24
25import numpy as np
26
27from tensorflow.core.framework import graph_pb2
28from tensorflow.python import tf2
29from tensorflow.python.data.experimental.ops import cardinality
30from tensorflow.python.data.ops import dataset_ops
31from tensorflow.python.data.ops import iterator_ops
32from tensorflow.python.data.ops import options as options_lib
33from tensorflow.python.eager import context
34from tensorflow.python.framework import composite_tensor
35from tensorflow.python.framework import dtypes
36from tensorflow.python.framework import errors
37from tensorflow.python.framework import ops
38from tensorflow.python.framework import smart_cond
39from tensorflow.python.framework import sparse_tensor
40from tensorflow.python.framework import tensor_spec
41from tensorflow.python.framework import tensor_util
42from tensorflow.python.keras import backend
43from tensorflow.python.keras import callbacks as cbks
44from tensorflow.python.keras import losses
45from tensorflow.python.keras import metrics as metrics_module
46from tensorflow.python.keras.utils import data_utils
47from tensorflow.python.keras.utils import generic_utils
48from tensorflow.python.keras.utils import losses_utils
49from tensorflow.python.keras.utils import tf_inspect
50from tensorflow.python.ops import array_ops
51from tensorflow.python.ops import gen_array_ops
52from tensorflow.python.ops import math_ops
53from tensorflow.python.ops import sparse_ops
54from tensorflow.python.ops.ragged import ragged_tensor
55from tensorflow.python.ops.ragged import ragged_tensor_value
56from tensorflow.python.platform import tf_logging as logging
57from tensorflow.python.util import nest
58
59
60def is_composite_or_composite_value(tensor):
61  """Returns true if 'tensor' is a CompositeTensor or a CT Value object."""
62  # TODO(b/125094323): This should be isinstance(CompositeTensor) or
63  # isinstance(CompositeTensorValue) once we support that.
64  return isinstance(
65      tensor,
66      (composite_tensor.CompositeTensor, sparse_tensor.SparseTensorValue,
67       ragged_tensor_value.RaggedTensorValue))
68
69
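# A minimal usage sketch for `is_composite_or_composite_value`. It relies only
# on the `np` and `sparse_tensor` imports above; the literal values are
# illustrative, not taken from any real model.
def _composite_check_example():
  sp_value = sparse_tensor.SparseTensorValue(
      indices=np.array([[0, 0], [1, 1]]),
      values=np.array([1., 2.]),
      dense_shape=(2, 2))
  assert is_composite_or_composite_value(sp_value)  # CT value object.
  # Plain ndarrays (and ordinary Tensors) are not composites.
  assert not is_composite_or_composite_value(np.ones((2, 2)))

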
70class Aggregator(object, metaclass=abc.ABCMeta):
71  """Abstract base class used to aggregate batch-level outputs of a loop.
72
73  Attributes:
74    use_steps: Whether the loop is using `steps` or `batch_size`.
75    num_samples: Total number of samples: `batch_size * num_batches`.
76    steps: Total number of steps.
77    batch_size: Batch size. It is used for validation checks between inputs and
78      outputs.
79    results: What to return at the end of the aggregation loop.
80  """
81
82  def __init__(self, use_steps, num_samples=None, steps=None, batch_size=None):
83    self.use_steps = use_steps
84    self.num_samples = num_samples
85    self.steps = steps
86    self.batch_size = batch_size
87    self.results = []
88
89  @abc.abstractmethod
90  def create(self, batch_outs):
91    """Creates the initial results from the first batch outputs.
92
93    Args:
94      batch_outs: A list of batch-level outputs.
95    """
96    raise NotImplementedError('Must be implemented in subclasses.')
97
98  @abc.abstractmethod
99  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
100    """Aggregates batch-level results into total results.
101
102    Args:
103      batch_outs: A list of batch-level outputs.
104      batch_start: The start index of this batch. Always `None` if `use_steps`
105        is `True`.
106      batch_end: The end index of this batch. Always `None` if `use_steps` is
107        `True`.
108    """
109    raise NotImplementedError('Must be implemented in subclasses.')
110
111  @abc.abstractmethod
112  def finalize(self):
113    """Prepares the total results to be returned."""
114    raise NotImplementedError('Must be implemented in subclasses.')
115
116
117class MetricsAggregator(Aggregator):
118  """Aggregator that calculates loss and metrics info.
119
120  Attributes:
121    use_steps: Whether the loop is using `steps` or `batch_size`.
122    num_samples: Total number of samples: `batch_size*num_batches`.
123    steps: Total number of steps, i.e. number of times to iterate over a dataset
124      to cover all samples.
125  """
126
127  def __init__(self, use_steps, num_samples=None, steps=None):
128    super(MetricsAggregator, self).__init__(
129        use_steps=use_steps,
130        num_samples=num_samples,
131        steps=steps,
132        batch_size=None)
133
134  def create(self, batch_outs):
135    self.results = [0.] * len(batch_outs)
136
137  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
138    # Loss.
139    if self.use_steps:
140      self.results[0] += batch_outs[0]
141    else:
142      self.results[0] += batch_outs[0] * (batch_end - batch_start)
143    # Metrics (always stateful, just grab current values).
144    self.results[1:] = batch_outs[1:]
145
146  def finalize(self):
147    if not self.results:
148      raise ValueError('Empty training data.')
149    self.results[0] /= (self.num_samples or self.steps)
150
151
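# An illustrative sketch of the MetricsAggregator bookkeeping for a
# sample-count based loop. The batch outputs below are made up; metric values
# are assumed to already be running averages, so only the latest one is kept.
def _metrics_aggregator_example():
  aggregator = MetricsAggregator(use_steps=False, num_samples=10)
  aggregator.create([0.5, 0.8])             # [loss, metric] from first batch.
  aggregator.aggregate([0.5, 0.8], 0, 6)    # First batch covers 6 samples.
  aggregator.aggregate([0.3, 0.85], 6, 10)  # Second batch covers 4 samples.
  aggregator.finalize()
  # results[0] == (0.5 * 6 + 0.3 * 4) / 10 == 0.42; results[1] == 0.85.
  return aggregator.results

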
152def _append_sparse_tensor_value(target, to_append):
153  """Append sparse tensor value objects."""
154  # Make sure the sparse tensors are of the same size (except for the 0th dim).
155  if len(target.dense_shape) != len(to_append.dense_shape):
156    raise RuntimeError(
157        'Unable to concatenate %s and %s. The inner dense shapes do not '
158        'have the same number of dimensions (%s vs %s)' %
159        (target, to_append, target.dense_shape, to_append.dense_shape))
160
161  if target.dense_shape[1:] != to_append.dense_shape[1:]:
162    raise RuntimeError(
163        'Unable to concatenate %s and %s. The inner dense shapes do not '
164        'match inner dimensions (%s vs %s)' %
165        (target, to_append, target.dense_shape[1:], to_append.dense_shape[1:]))
166
167  # Add the to_append indices to target, updating the 0th value, and keeping
168  # track of the maximum so we know the final dense_shape of this tensor.
169  base_dim0_value = target.dense_shape[0]
170  max_dim0_value = target.dense_shape[0]
171  new_indices = target.indices
172  for index in to_append.indices:
173    # Here, we iterate through the sparse indices of the tensor to append. For
174    # each index, we update its zeroth value (the batch index) by adding the
175    # number of batch items in the tensor we are appending to (so an index
176    # of [0, 0, 1] for a value that is being appended to a tensor with 0th dim
177    # size 3 would become [3, 0, 1].)
178    index[0] += base_dim0_value
179    max_dim0_value = max(max_dim0_value, index[0])
180    new_indices = np.append(new_indices, [index], axis=0)
181
182  # Extend the values array to contain all of the appended values. These will
183  # be in the same order as the indices added above.
184  new_values = np.concatenate((target.values, to_append.values), axis=0)
185
186  # Create a new dense shape by replacing the value for the 0th dimension
187  # with the new max dim0 value.
188  new_dense_shape = list(target.dense_shape)
189  new_dense_shape[0] = max_dim0_value + 1
190  new_dense_shape = tuple(new_dense_shape)
191
192  return sparse_tensor.SparseTensorValue(
193      indices=new_indices, values=new_values, dense_shape=new_dense_shape)
194
195
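# A hand-worked sketch of `_append_sparse_tensor_value` with toy values:
# appending a batch of 2 rows to a batch of 3 rows shifts the appended batch
# indices by 3 and grows the 0th dense dimension accordingly.
def _append_sparse_value_example():
  target = sparse_tensor.SparseTensorValue(
      indices=np.array([[0, 0], [2, 1]]),
      values=np.array([1., 2.]),
      dense_shape=(3, 4))
  to_append = sparse_tensor.SparseTensorValue(
      indices=np.array([[0, 3], [1, 0]]),
      values=np.array([3., 4.]),
      dense_shape=(2, 4))
  merged = _append_sparse_tensor_value(target, to_append)
  # merged.indices == [[0, 0], [2, 1], [3, 3], [4, 0]]
  # merged.dense_shape == (5, 4)
  return merged

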
196def _append_ragged_tensor_value(target, to_append):
197  """Append ragged tensor value objects."""
198  # Make sure the ragged tensors are of the same size (save for the 0th dim).
199  if len(target.shape) != len(to_append.shape):
200    raise RuntimeError('Unable to concatenate %s and %s' % (target, to_append))
201
202  if target.shape[1:] != to_append.shape[1:]:
203    raise RuntimeError('Unable to concatenate %s and %s' % (target, to_append))
204
205  adjusted_row_splits = to_append.row_splits[1:] + target.row_splits[-1]
206  new_row_splits = np.append(target.row_splits, adjusted_row_splits)
207  if isinstance(target.values, ragged_tensor_value.RaggedTensorValue):
208    new_values = _append_ragged_tensor_value(target.values, to_append.values)
209  else:
210    new_values = np.concatenate((target.values, to_append.values), axis=0)
211
212  return ragged_tensor_value.RaggedTensorValue(new_values, new_row_splits)
213
214
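# A toy sketch of the row_splits arithmetic in `_append_ragged_tensor_value`:
# the appended splits are shifted by the last split of the target.
def _append_ragged_value_example():
  target = ragged_tensor_value.RaggedTensorValue(
      values=np.array([1, 2, 3]),
      row_splits=np.array([0, 2, 3], dtype=np.int64))   # [[1, 2], [3]]
  to_append = ragged_tensor_value.RaggedTensorValue(
      values=np.array([4, 5]),
      row_splits=np.array([0, 1, 2], dtype=np.int64))   # [[4], [5]]
  merged = _append_ragged_tensor_value(target, to_append)
  # to_append's splits [1, 2] are shifted by target.row_splits[-1] == 3,
  # giving row_splits [0, 2, 3, 4, 5] and values [1, 2, 3, 4, 5].
  return merged

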
215def _append_composite_tensor(target, to_append):
216  """Helper function to append composite tensors to each other in the 0 axis.
217
218  In order to support batching within a fit/evaluate/predict call, we need
219  to be able to aggregate within a CompositeTensor. Unfortunately, the CT
220  API currently does not make this easy - especially in V1 mode, where we're
221  working with CompositeTensor Value objects that have no connection with the
222  CompositeTensors that created them.
223
224  Args:
225    target: CompositeTensor or CompositeTensor value object that will be
226      appended to.
227    to_append: CompositeTensor or CompositeTensor value object to append to
228      'target'.
229
230  Returns:
231    A CompositeTensor or CompositeTensor value object.
232
233  Raises:
234    RuntimeError: if concatenation is not possible.
235  """
236  if type(target) is not type(to_append):
237    raise RuntimeError('Unable to concatenate %s and %s' %
238                       (type(target), type(to_append)))
239
240  # Perform type-specific concatenation.
241  # TODO(b/125094323): This should be replaced by a simple call to
242  # target.append() that should work on all of the below classes.
243
244  # If we're seeing a CompositeTensor here, we know it's because we're in
245  # Eager mode (or else we'd have evaluated the CT to a CT Value object
246  # already). Therefore, it's safe to call concat() on it without evaluating
247  # the result any further. If not - that is, if we're seeing a
248  # SparseTensorValue or a RaggedTensorValue - we need to hand-update it
249  # since we're outside of the graph anyways.
250  if isinstance(target, sparse_tensor.SparseTensor):
251    # We need to invoke the sparse version of concatenate here - tf.concat
252    # won't work.
253    return sparse_ops.sparse_concat(sp_inputs=[target, to_append], axis=0)
254  elif isinstance(target, ragged_tensor.RaggedTensor):
255    return array_ops.concat([target, to_append], axis=0)
256  elif isinstance(target, sparse_tensor.SparseTensorValue):
257    return _append_sparse_tensor_value(target, to_append)
258  elif isinstance(target, ragged_tensor_value.RaggedTensorValue):
259    return _append_ragged_tensor_value(target, to_append)
260  else:
261    raise RuntimeError('Attempted to concatenate unsupported object %s.' %
262                       type(target))
263
264
265class ConcatAggregator(Aggregator):
266  """Combine tensor-likes which cannot be merged on the fly.
267
268  This class expects to aggregate a single tensor-like rather than a nested
269  structure of tensor-likes.
270  """
271
272  def __init__(self, batch_size):
273    self.composite = None
274    super(ConcatAggregator, self).__init__(
275        use_steps=True, num_samples=None, steps=None, batch_size=batch_size)
276
277  def create(self, batch_element):
278    self.composite = is_composite_or_composite_value(batch_element)
279
280  def aggregate(self, batch_element, batch_start=None, batch_end=None):
281
282    # TODO(psv): Add num_samples check here to detect when output batch
283    # #samples is < batch size and != input batch #samples.
284    if self.batch_size and self.batch_size < batch_element.shape[0]:
285      raise ValueError(
286          'Mismatch between expected batch size and model output batch size. '
287          'Output shape = {}, expected output shape = {}'.format(
288              batch_element.shape,
289              (self.batch_size,) + batch_element.shape[1:]))
290    self.results.append(batch_element)
291
292  def finalize(self):
293    # Special case of single batch inference which skips a copy.
294    if len(self.results) == 1:
295      self.results = self.results[0]
296
297    elif self.composite:
298      # TODO(taylorrobie): efficiently concatenate.
299      results = self.results[0]
300      for r in self.results[1:]:
301        results = _append_composite_tensor(results, r)
302      self.results = results
303
304    else:
305      self.results = np.concatenate(self.results, axis=0)
306
307
308_COPY_THREADS = 4
309_COPY_POOL = None
310
311
312def get_copy_pool():
313  """Shared threadpool for copying arrays.
314
315  Pool instantiation takes ~ 2ms, so a singleton pool is used rather than
316  creating a pool per SliceAggregator.
317
318  Returns:
319    The global copy threadpool.
320  """
321  global _COPY_POOL
322  if _COPY_POOL is None:
323    _COPY_POOL = multiprocessing.pool.ThreadPool(_COPY_THREADS)
324    atexit.register(_COPY_POOL.close)
325  return _COPY_POOL
326
327
328class SliceAggregator(Aggregator):
329  """Combine arrays where the final size is known.
330
331  This class expects to aggregate a single tensor-like rather than a nested
332  structure of tensor-likes.
333
334  NumPy copies are an operation that threads handle quite well because all of
335  the heavy lifting is in C and does not need the GIL. Moreover, we can perform
336  lock-free writes to the same buffer in multiple threads because the nature of
337  result aggregation guarantees that either the indices are disjoint or the
338  aggregator will throw an exception in finalize. Finally, because aggregation
339  is performed on the slowest varying dimension, assignments for a given batch
340  will write to contiguous blocks of memory, further minimizing contention.
341
342  There is, however, some scheduling and context switching overhead which will
343  offset the gains from pipelining the slice assignment. Below a given threshold
344  it is faster to simply assign in the main thread rather than enqueue the
345  assignment in a side thread. The exact threshold will vary from system to
346  system, but the time is not very sensitive to the exact transition so a value
347  of 2 ** 14 was chosen which should be reasonable on most systems.
348  """
349
350  _BINARY_SIZE_THRESHOLD = 2 ** 14
351  _MAX_COPY_SECONDS = 300
352
353  def __init__(self, num_samples, batch_size):
354    self._async_copies = []
355    self._pool = get_copy_pool()
356    self._errors = []
357    super(SliceAggregator, self).__init__(
358        use_steps=False,
359        num_samples=num_samples,
360        steps=None,
361        batch_size=batch_size)
362
363  def create(self, batch_element):
364    # This step does not need to be pipelined because NumPy empty array
365    # initialization is effectively instantaneous.
366    shape = (self.num_samples,) + batch_element.shape[1:]
367    dtype = batch_element.dtype
368
369    self.results = np.empty(shape=shape, dtype=dtype)
370
371  def aggregate(self, batch_element, batch_start, batch_end):
372    # Fail early.
373    if self._errors:
374      raise self._errors[0]
375
376    # In the special case of single batch inference, no copy is needed.
377    if batch_end - batch_start == self.num_samples:
378      if self.num_samples != batch_element.shape[0]:
379        raise ValueError(
380            'Mismatch between expected batch size and model output batch size. '
381            'Output shape = {}, expected output shape = {}'.format(
382                batch_element.shape, self.results.shape))
383
384      self.results = batch_element
385      return
386
387    # This is an approximate threshold, so we don't need to consider the number
388    # of bytes per element.
389    num_elements = np.prod(batch_element.shape)
390    if num_elements < self._BINARY_SIZE_THRESHOLD:
391      self.results[batch_start:batch_end] = batch_element
392    else:
393      is_finished = threading.Event()
394      self._pool.apply_async(
395          self._slice_assign,
396          args=(batch_element, batch_start, batch_end, is_finished))
397      self._async_copies.append(is_finished)
398
399  def _slice_assign(self, batch_element, batch_start, batch_end, is_finished):
400    """Legacy utility method to slice input arrays."""
401    try:
402      self.results[batch_start:batch_end] = batch_element
403
404    except Exception as e:  # pylint: disable=broad-except
405      # `_slice_assign` should only be called in threads and exceptions raised
406      # in threads do not carry over to the main thread. So instead we perform a
407      # a broad catch in the thread and then store the exception to be re-raised
408      # in the main thread.
409      self._errors.append(e)
410
411    finally:
412      is_finished.set()
413
414  def finalize(self):
415    start_time = time.time()
416    for is_finished in self._async_copies:
417      timeout = max([0., self._MAX_COPY_SECONDS - (time.time() - start_time)])
418      if not is_finished.wait(timeout):
419        raise ValueError('Timed out waiting for copy to complete.')
420
421    if self._errors:
422      raise self._errors[0]
423
424
425class OutputsAggregator(Aggregator):
426  """Aggregator that concatenates outputs."""
427
428  _structure = None
429
430  def create(self, batch_outs):
431    # SparseTensorValue is a named tuple which nest will flatten, so we need
432    # to guard it to properly handle the structure.
433    self._structure = nest.get_traverse_shallow_structure(
434        lambda x: not is_composite_or_composite_value(x), batch_outs)
435    batch_outs = nest.flatten_up_to(self._structure, batch_outs)
436
437    for batch_element in batch_outs:
438      if is_composite_or_composite_value(batch_element):
439        # If the output is not an ndarray, it will be either a composite tensor
440        # or a composite tensor's Value object. In either case, we can't
441        # allocate an array to hold the object - we'll handle it later.
442        self.results.append(ConcatAggregator(self.batch_size))
443      elif isinstance(batch_element, np.ndarray):
444        self.results.append(
445            (ConcatAggregator(self.batch_size) if self.use_steps else
446             SliceAggregator(self.num_samples, self.batch_size)))
447      else:
448        # This is not an ndarray, a CompositeTensor, or a CompositeTensorValue.
449        # Fail fast rather than trying to concatenate it.
450        raise RuntimeError('Attempted to aggregate unsupported object {}.'
451                           .format(batch_element))
452
453      self.results[-1].create(batch_element)
454
455  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
456    batch_outs = nest.flatten_up_to(self._structure, batch_outs)
457    for batch_element, result in zip(batch_outs, self.results):
458      result.aggregate(batch_element, batch_start, batch_end)
459
460  def finalize(self):
461    for result in self.results:
462      result.finalize()
463    self.results = [i.results for i in self.results]
464    self.results = nest.pack_sequence_as(self._structure, self.results)
465
466
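# A rough sketch of how OutputsAggregator is driven by a prediction loop.
# The two ndarray batches below are hypothetical; with `use_steps=False` each
# output is backed by a SliceAggregator writing into a preallocated buffer.
def _outputs_aggregator_example():
  aggregator = OutputsAggregator(use_steps=False, num_samples=8, batch_size=4)
  batch_0 = np.zeros((4, 3))
  batch_1 = np.ones((4, 3))
  aggregator.create(batch_0)        # Sets up one aggregator per model output.
  aggregator.aggregate(batch_0, 0, 4)
  aggregator.aggregate(batch_1, 4, 8)
  aggregator.finalize()
  return aggregator.results         # -> ndarray of shape (8, 3).

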
467def get_progbar(model, count_mode, include_metrics=True):
468  """Get Progbar."""
469  if include_metrics:
470    stateful_metric_names = getattr(model, 'metrics_names', None)
471    if stateful_metric_names:
472      stateful_metric_names = stateful_metric_names[1:]  # Exclude `loss`
473  else:
474    stateful_metric_names = None
475  return cbks.ProgbarLogger(count_mode, stateful_metrics=stateful_metric_names)
476
477
478def check_num_samples(ins, batch_size=None, steps=None, steps_name='steps'):
479  """Determine the number of samples provided for training and evaluation.
480
481  The number of samples is not defined when running with `steps`,
482  in which case the number of samples is set to `None`.
483
484  Args:
485      ins: List of tensors to be fed to the Keras function.
486      batch_size: Integer batch size or `None` if not defined.
487      steps: Total number of steps (batches of samples) before declaring
488        `_predict_loop` finished. Ignored with the default value of `None`.
489      steps_name: The public API's parameter name for `steps`.
490
491  Raises:
492      ValueError: when `steps` is `None` and the attribute `ins.shape`
493      does not exist. Also raises ValueError when `steps` is not `None`
494      and `batch_size` is not `None` because they are mutually
495      exclusive.
496
497  Returns:
498      When steps is `None`, returns the number of samples to be
499      processed based on the size of the first dimension of the
500      first input numpy array. When steps is not `None` and
501      `batch_size` is `None`, returns `None`.
502  """
503  if steps is not None and batch_size is not None:
504    raise ValueError('If ' + steps_name +
505                     ' is set, the `batch_size` must be None.')
506  if check_steps_argument(ins, steps, steps_name):
507    return None
508
509  if hasattr(ins[0], 'shape'):
510    return int(ins[0].shape[0])
511  return None  # Edge case where ins == [static_learning_phase]
512
513
514def standardize_single_array(x, expected_shape=None):
515  """Expand data of shape (x,) to (x, 1), unless len(expected_shape)==1."""
516  if x is None:
517    return None
518
519  if is_composite_or_composite_value(x):
520    return x
521
522  if isinstance(x, int):
523    raise ValueError(
524        'Expected an array data type but received an integer: {}'.format(x))
525
526  if (x.shape is not None and len(x.shape) == 1 and
527      (expected_shape is None or len(expected_shape) != 1)):
528    if tensor_util.is_tf_type(x):
529      x = array_ops.expand_dims(x, axis=1)
530    else:
531      x = np.expand_dims(x, 1)
532  return x
533
534
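# A quick sketch of the rank-1 expansion rule (plain NumPy input assumed):
def _standardize_single_array_example():
  vec = np.arange(4)                                     # shape (4,)
  assert standardize_single_array(vec).shape == (4, 1)   # expanded
  # With a rank-1 expected shape, the vector is left untouched.
  assert standardize_single_array(vec, expected_shape=(4,)).shape == (4,)

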
535def get_composite_shape(tensor):
536  """Returns the shape of the passed composite tensor."""
537  if isinstance(tensor, sparse_tensor.SparseTensorValue):
538    # SparseTensorValues use a 'dense_shape' attribute
539    return tensor.dense_shape
540  else:
541    return tensor.shape
542
543
544def standardize_input_data(data,
545                           names,
546                           shapes=None,
547                           check_batch_axis=True,
548                           exception_prefix=''):
549  """Normalizes inputs and targets provided by users.
550
551  Users may pass data as a list of arrays, dictionary of arrays,
552  or as a single array. We normalize this to an ordered list of
553  arrays (same order as `names`), while checking that the provided
554  arrays have shapes that match the network's expectations.
555
556  Args:
557      data: User-provided input data (polymorphic).
558      names: List of expected array names.
559      shapes: Optional list of expected array shapes.
560      check_batch_axis: Boolean; whether to check that the batch axis of the
561        arrays matches the expected value found in `shapes`.
562      exception_prefix: String prefix used for exception formatting.
563
564  Returns:
565      List of standardized input arrays (one array per model input).
566
567  Raises:
568      ValueError: in case of improperly formatted user-provided data.
569  """
570  try:
571    data_len = len(data)
572  except TypeError:
573    # For instance if data is `None` or a symbolic Tensor.
574    data_len = None
575
576  if not names:
577    if data_len and not isinstance(data, dict):
578      raise ValueError(
579          'Error when checking model ' + exception_prefix + ': '
580          'expected no data, but got:', data)
581    return []
582  if data is None:
583    return [None for _ in range(len(names))]
584
585  if isinstance(data, dict):
586    try:
587      data = [
588          data[x].values
589          if data[x].__class__.__name__ == 'DataFrame' else data[x]
590          for x in names
591      ]
592    except KeyError as e:
593      raise ValueError('No data provided for "' + e.args[0] + '". Need data '
594                       'for each key in: ' + str(names))
595  elif isinstance(data, (list, tuple)):
596    if isinstance(data[0], (list, tuple)):
597      data = [np.asarray(d) for d in data]
598    elif len(names) == 1 and isinstance(data[0], (float, int)):
599      data = [np.asarray(data)]
600    else:
601      data = [
602          x.values if x.__class__.__name__ == 'DataFrame' else x for x in data
603      ]
604  else:
605    data = data.values if data.__class__.__name__ == 'DataFrame' else data
606    data = [data]
607
608  if shapes is not None:
609    data = [
610        standardize_single_array(x, shape) for (x, shape) in zip(data, shapes)
611    ]
612  else:
613    data = [standardize_single_array(x) for x in data]
614
615  if len(data) != len(names):
616    if data and hasattr(data[0], 'shape'):
617      raise ValueError('Error when checking model ' + exception_prefix +
618                       ': the list of Numpy arrays that you are passing to '
619                       'your model is not the size the model expected. '
620                       'Expected to see ' + str(len(names)) + ' array(s), ' +
621                       'for inputs ' + str(names) + ' but instead got the '
622                       'following list of ' + str(len(data)) + ' arrays: ' +
623                       str(data)[:200] + '...')
624    elif len(names) > 1:
625      raise ValueError('Error when checking model ' + exception_prefix +
626                       ': you are passing a list as input to your model, '
627                       'but the model expects a list of ' + str(len(names)) +
628                       ' Numpy arrays instead. The list you passed was: ' +
629                       str(data)[:200])
630    elif len(data) == 1 and not hasattr(data[0], 'shape'):
631      raise TypeError('Error when checking model ' + exception_prefix +
632                      ': data should be a Numpy array, or list/dict of '
633                      'Numpy arrays. Found: ' + str(data)[:200] + '...')
634    elif len(names) == 1:
635      data = [np.asarray(data)]
636
637  # Check shapes compatibility.
638  if shapes:
639    for i in range(len(names)):
640      if shapes[i] is not None:
641        if tensor_util.is_tf_type(data[i]):
642          tensorshape = data[i].shape
643          if not tensorshape:
644            continue
645          data_shape = tuple(tensorshape.as_list())
646        elif is_composite_or_composite_value(data[i]):
647          tensorshape = get_composite_shape(data[i])
648          data_shape = tuple(tensorshape.as_list())
649        else:
650          data_shape = data[i].shape
651
652        shape = shapes[i]
653        if len(data_shape) != len(shape):
654          raise ValueError('Error when checking ' + exception_prefix +
655                           ': expected ' + names[i] + ' to have ' +
656                           str(len(shape)) + ' dimensions, but got array '
657                           'with shape ' + str(data_shape))
658        if not check_batch_axis:
659          data_shape = data_shape[1:]
660          shape = shape[1:]
661        for dim, ref_dim in zip(data_shape, shape):
662          if ref_dim != dim and ref_dim is not None and dim is not None:
663            raise ValueError('Error when checking ' + exception_prefix +
664                             ': expected ' + names[i] + ' to have shape ' +
665                             str(shape) + ' but got array with shape ' +
666                             str(data_shape))
667  return data
668
669
670def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
671  """Maps `sample_weight` or `class_weight` to model outputs.
672
673  Args:
674      x_weight: User-provided `sample_weight` or `class_weight` argument.
675      output_names: List of output names (strings) in the model.
676      weight_type: A string used purely for exception printing.
677
678  Returns:
679      A list of `sample_weight` or `class_weight` where there is exactly
680          one element per model output.
681
682  Raises:
683      ValueError: In case of invalid user-provided argument.
684  """
685  if x_weight is None or (isinstance(x_weight, (list, tuple)) and
686                          len(x_weight) == 0):  # pylint: disable=g-explicit-length-test
687    return [None for _ in output_names]
688  if len(output_names) == 1:
689    if isinstance(x_weight, (list, tuple)) and len(x_weight) == 1:
690      return x_weight
691    if isinstance(x_weight, dict) and output_names[0] in x_weight:
692      return [x_weight[output_names[0]]]
693    else:
694      return [x_weight]
695  if isinstance(x_weight, (list, tuple)):
696    if len(x_weight) != len(output_names):
697      raise ValueError('Provided `' + weight_type + '` was a list of ' +
698                       str(len(x_weight)) + ' elements, but the model has ' +
699                       str(len(output_names)) + ' outputs. '
700                       'You should provide one `' + weight_type + '` '
701                       'array per model output.')
702    return x_weight
703  if isinstance(x_weight, collections.abc.Mapping):
704    generic_utils.check_for_unexpected_keys(weight_type, x_weight, output_names)
705    x_weights = []
706    for name in output_names:
707      x_weights.append(x_weight.get(name))
708    return x_weights
709  else:
710    raise TypeError('The model has multiple outputs, so `' + weight_type + '` '
711                    'should be either a list or a dict. '
712                    'Provided `' + weight_type + '` type not understood: ' +
713                    str(x_weight))
714
715
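# A short sketch of the dict-to-list mapping for a two-output model. The
# output names and weight values are made up for illustration; missing
# outputs map to `None`.
def _standardize_weights_mapping_example():
  names = ['out_a', 'out_b']
  weights_by_name = {'out_a': np.ones(3)}
  return standardize_sample_or_class_weights(
      weights_by_name, names, 'sample_weight')  # -> [np.ones(3), None]

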
716def standardize_class_weights(class_weight, output_names):
717  return standardize_sample_or_class_weights(class_weight, output_names,
718                                             'class_weight')
719
720
721def standardize_sample_weights(sample_weight, output_names):
722  return standardize_sample_or_class_weights(sample_weight, output_names,
723                                             'sample_weight')
724
725
726def check_array_lengths(inputs, targets, weights=None):
727  """Does user input validation for numpy arrays.
728
729  Args:
730      inputs: list of Numpy arrays of inputs.
731      targets: list of Numpy arrays of targets.
732      weights: list of Numpy arrays of sample weights.
733
734  Raises:
735      ValueError: in case of incorrectly formatted data.
736  """
737
738  def is_tensor_or_composite_tensor(x):
739    return tensor_util.is_tf_type(x) or is_composite_or_composite_value(x)
740
741  def set_of_lengths(x):
742    # Returns the set of distinct first-dimension lengths found in x
743    # (skipping tensors); None maps to an empty set.
744    if x is None:
745      return {}
746    else:
747      return set([
748          y.shape[0]
749          for y in x
750          if y is not None and not is_tensor_or_composite_tensor(y)
751      ])
752
753  set_x = set_of_lengths(inputs)
754  set_y = set_of_lengths(targets)
755  set_w = set_of_lengths(weights)
756  if len(set_x) > 1:
757    raise ValueError('All input arrays (x) should have '
758                     'the same number of samples. Got array shapes: ' +
759                     str([x.shape for x in inputs]))
760  if len(set_y) > 1:
761    raise ValueError('All target arrays (y) should have '
762                     'the same number of samples. Got array shapes: ' +
763                     str([y.shape for y in targets]))
764  if set_x and set_y and list(set_x)[0] != list(set_y)[0]:
765    raise ValueError('Input arrays should have '
766                     'the same number of samples as target arrays. '
767                     'Found ' + str(list(set_x)[0]) + ' input samples '
768                     'and ' + str(list(set_y)[0]) + ' target samples.')
769  if len(set_w) > 1:
770    raise ValueError('All sample_weight arrays should have '
771                     'the same number of samples. Got array shapes: ' +
772                     str([w.shape for w in weights]))
773  if set_y and set_w and list(set_y)[0] != list(set_w)[0]:
774    raise ValueError('Sample_weight arrays should have '
775                     'the same number of samples as target arrays. Got ' +
776                     str(list(set_y)[0]) + ' target samples and ' +
777                     str(list(set_w)[0]) + ' sample_weight samples.')
778
779
780def check_loss_and_target_compatibility(targets, loss_fns, output_shapes):
781  """Does validation on the compatibility of targets and loss functions.
782
783  This helps prevent users from using loss functions incorrectly. This check
784  is purely for UX purposes.
785
786  Args:
787      targets: list of Numpy arrays of targets.
788      loss_fns: list of loss functions.
789      output_shapes: list of shapes of model outputs.
790
791  Raises:
792      ValueError: if a loss function or target array
793          is incompatible with an output.
794  """
795  key_loss_fns = {
796      losses.mean_squared_error, losses.binary_crossentropy,
797      losses.categorical_crossentropy
798  }
799  key_loss_classes = (losses.MeanSquaredError, losses.BinaryCrossentropy,
800                      losses.CategoricalCrossentropy)
801  for y, loss, shape in zip(targets, loss_fns, output_shapes):
802    if y is None or loss is None or tensor_util.is_tf_type(y):
803      continue
804    if losses.is_categorical_crossentropy(loss):
805      if y.shape[-1] == 1:
806        raise ValueError('You are passing a target array of shape ' +
807                         str(y.shape) +
808                         ' while using as loss `categorical_crossentropy`. '
809                         '`categorical_crossentropy` expects '
810                         'targets to be binary matrices (1s and 0s) '
811                         'of shape (samples, classes). '
812                         'If your targets are integer classes, '
813                         'you can convert them to the expected format via:\n'
814                         '```\n'
815                         'from keras.utils import to_categorical\n'
816                         'y_binary = to_categorical(y_int)\n'
817                         '```\n'
818                         '\n'
819                         'Alternatively, you can use the loss function '
820                         '`sparse_categorical_crossentropy` instead, '
821                         'which does expect integer targets.')
822
823    is_loss_wrapper = isinstance(loss, losses.LossFunctionWrapper)
824    if (isinstance(loss, key_loss_classes) or (is_loss_wrapper and
825                                               (loss.fn in key_loss_fns))):
826      for target_dim, out_dim in zip(y.shape[1:], shape[1:]):
827        if out_dim is not None and target_dim != out_dim:
828          loss_name = loss.name
829          if loss_name is None:
830            loss_type = loss.fn if is_loss_wrapper else type(loss)
831            loss_name = loss_type.__name__
832          raise ValueError('A target array with shape ' + str(y.shape) +
833                           ' was passed for an output of shape ' + str(shape) +
834                           ' while using as loss `' + loss_name + '`. '
835                           'This loss expects targets to have the same shape '
836                           'as the output.')
837
838
839def collect_per_output_metric_info(metrics,
840                                   output_names,
841                                   output_shapes,
842                                   loss_fns,
843                                   from_serialized=False,
844                                   is_weighted=False):
845  """Maps metric names and functions to model outputs.
846
847  Args:
848      metrics: a list or a list of lists or a dict of metric functions.
849      output_names: a list of the names (strings) of model outputs.
850      output_shapes: a list of the shapes of model outputs.
851      loss_fns: a list of the loss functions corresponding to the model outputs.
852      from_serialized: whether the model the metrics are being sourced from is
853        being initialized from a serialized format.
854      is_weighted: Boolean indicating whether the given metrics are weighted.
855
856  Returns:
857      A list (one entry per model output) of dicts.
858      For instance, if the model has 2 outputs, and for the first output
859      we want to compute "binary_accuracy" and "binary_crossentropy",
860      and just "binary_accuracy" for the second output,
861      the list would look like: `[{
862          'acc': binary_accuracy(),
863          'ce': binary_crossentropy(),
864        }, {
865          'acc': binary_accuracy(),
866        }]`
867
868  Raises:
869      TypeError: if an incorrect type is passed for the `metrics` argument.
870  """
871  if not metrics:
872    return [{} for _ in output_names]
873
874  if isinstance(metrics, list):
875    any_sub_list = any(isinstance(m, list) for m in metrics)
876    if any_sub_list:
877      if len(metrics) != len(output_names):
878        raise ValueError('When passing a list of lists as `metrics`, '
879                         'it should have one entry per model output. '
880                         'The model has ' + str(len(output_names)) +
881                         ' outputs, but you passed metrics=' + str(metrics))
882      # User has provided a list of len = len(outputs).
883      nested_metrics = [generic_utils.to_list(m) for m in metrics]
884    else:
885      # If it is a single list we then apply all metrics to all outputs.
886      if len(output_names) > 1:
887        nested_metrics = []
888        for _ in output_names:
889          nested_metrics.append(
890              [metrics_module.clone_metric(m) for m in metrics])
891      else:
892        nested_metrics = [metrics]
893  elif isinstance(metrics, collections.abc.Mapping):
894    generic_utils.check_for_unexpected_keys('metrics', metrics, output_names)
895    nested_metrics = []
896    for name in output_names:
897      output_metrics = generic_utils.to_list(metrics.get(name, []))
898      nested_metrics.append(output_metrics)
899  else:
900    raise TypeError('Type of `metrics` argument not understood. '
901                    'Expected a list or dictionary, found: ' + str(metrics))
902
903  per_output_metrics = []
904  for i, metrics in enumerate(nested_metrics):
905    metrics_dict = collections.OrderedDict()
906    for metric in metrics:
907      metric_name = get_metric_name(metric, is_weighted)
908      metric_fn = get_metric_function(
909          metric, output_shape=output_shapes[i], loss_fn=loss_fns[i])
910      metric_fn._from_serialized = from_serialized  # pylint: disable=protected-access
911
912      # If the metric function is not stateful, we create a stateful version.
913      if not isinstance(metric_fn, metrics_module.Metric):
914        metric_fn = metrics_module.MeanMetricWrapper(
915            metric_fn, name=metric_name)
916        # If the metric is being revived from something stateless, such as a
917        # string (e.g. "accuracy"), we may need to later reapply transformations
918        # such as renaming.
919        metric_fn._from_serialized = False  # pylint: disable=protected-access
920      metrics_dict[metric_name] = metric_fn
921    per_output_metrics.append(metrics_dict)
922
923  return per_output_metrics
924
925
926def batch_shuffle(index_array, batch_size):
927  """Shuffles an array in a batch-wise fashion.
928
929  Useful for shuffling HDF5 arrays
930  (where one cannot access arbitrary indices).
931
932  Args:
933      index_array: array of indices to be shuffled.
934      batch_size: integer.
935
936  Returns:
937      The `index_array` array, shuffled in a batch-wise fashion.
938  """
939  batch_count = int(len(index_array) / batch_size)
940  # To reshape, the length must be cleanly divisible by the batch size, so we
941  # stash the extra items and re-append them after shuffling.
942  last_batch = index_array[batch_count * batch_size:]
943  index_array = index_array[:batch_count * batch_size]
944  index_array = index_array.reshape((batch_count, batch_size))
945  np.random.shuffle(index_array)
946  index_array = index_array.flatten()
947  return np.append(index_array, last_batch)
948
949
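# A tiny sketch of the batch-wise shuffle: whole batches are permuted while
# the trailing remainder (here the single index 6) stays at the end.
def _batch_shuffle_example():
  indices = np.arange(7)                     # [0, 1, 2, 3, 4, 5, 6]
  shuffled = batch_shuffle(indices, batch_size=3)
  # `shuffled` is some ordering of the batches [0, 1, 2] and [3, 4, 5],
  # followed by the leftover [6], e.g. [3, 4, 5, 0, 1, 2, 6].
  return shuffled

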
950def standardize_weights(y,
951                        sample_weight=None,
952                        class_weight=None,
953                        sample_weight_mode=None):
954  """Performs sample weight validation and standardization.
955
956  Everything gets normalized to a single sample-wise (or timestep-wise)
957  weight array. If both `sample_weight` and `class_weight` are provided,
958  the weights are multiplied.
959
960  Args:
961      y: Numpy array or Tensor of model targets to be weighted.
962      sample_weight: User-provided `sample_weight` argument.
963      class_weight: User-provided `class_weight` argument.
964      sample_weight_mode: One of `None` or `"temporal"`. `"temporal"` indicates
965        that we expect 2D weight data that will be applied to the last 2
966        dimensions of the targets (i.e. we are weighting timesteps, not
967        samples).
968
969  Returns:
970      A numpy array of target weights, one entry per sample to weight.
971
972  Raises:
973      ValueError: In case of invalid user-provided arguments.
974  """
975  # Iterator may return sample_weight as 1-tuple
976  if isinstance(sample_weight, tuple):
977    sample_weight = sample_weight[0]
978  if sample_weight_mode is not None and sample_weight_mode != 'samplewise':
979    if sample_weight_mode != 'temporal':
980      raise ValueError('`sample_weight_mode` '
981                       'should be None or "temporal". '
982                       'Found: ' + str(sample_weight_mode))
983    if len(y.shape) < 3:
984      raise ValueError('Found a sample_weight array for '
985                       'an input with shape ' + str(y.shape) + '. '
986                       'Timestep-wise sample weighting (use of '
987                       'sample_weight_mode="temporal") is restricted to '
988                       'outputs that are at least 3D, i.e. that have '
989                       'a time dimension.')
990    if sample_weight is not None and len(sample_weight.shape) != 2:
991      raise ValueError('Found a sample_weight array with shape ' +
992                       str(sample_weight.shape) + '. '
993                       'In order to use timestep-wise sample weighting, '
994                       'you should pass a 2D sample_weight array.')
995  else:
996    if sample_weight is not None and len(sample_weight.shape) != 1:
997      raise ValueError(
998          'Found a sample_weight array with shape {}. In order to '
999          'use timestep-wise sample weights, you should specify '
1000          'sample_weight_mode="temporal" in compile(); found "{}" '
1001          'instead. If you just mean to use sample-wise weights, '
1002          'make sure your sample_weight array is 1D.'.format(
1003              sample_weight.shape, sample_weight_mode))
1004
1005  if sample_weight is not None:
1006    if len(sample_weight.shape) > len(y.shape):
1007      raise ValueError('Found a sample_weight with shape ' +
1008                       str(sample_weight.shape) + '. '
1009                       'Expected sample_weight with rank '
1010                       'less than or equal to ' + str(len(y.shape)))
1011
1012    if (not tensor_util.is_tf_type(sample_weight) and
1013        y.shape[:sample_weight.ndim] != sample_weight.shape):
1014      raise ValueError('Found a sample_weight array with shape ' +
1015                       str(sample_weight.shape) + ' for an input with shape ' +
1016                       str(y.shape) + '. '
1017                       'sample_weight cannot be broadcast.')
1018
1019  # Class weights applied per-sample.
1020  class_sample_weight = None
1021  if isinstance(class_weight, dict):
1022    if len(y.shape) > 2:
1023      raise ValueError('`class_weight` not supported for '
1024                       '3+ dimensional targets.')
1025
1026    if tensor_util.is_tf_type(y):
1027      # Few classes are expected, so densifying is reasonable.
1028      keys = np.array(sorted(class_weight.keys()))
1029      values = np.array([class_weight[i] for i in keys])
1030      weight_vector = np.zeros(np.max(keys) + 1)
1031      weight_vector[:] = np.nan
1032      weight_vector[keys] = values
1033
1034      y_classes = smart_cond.smart_cond(
1035          len(y.shape.as_list()) == 2 and backend.shape(y)[1] > 1,
1036          lambda: backend.argmax(y, axis=1),
1037          lambda: math_ops.cast(backend.reshape(y, (-1,)), dtypes.int64))
1038      class_sample_weight = array_ops.gather(weight_vector, y_classes)
1039      gen_array_ops.check_numerics(
1040          class_sample_weight,
1041          'Invalid classes or class weights detected. NaN values indicate that '
1042          'an appropriate class weight could not be determined.')
1043      class_sample_weight = math_ops.cast(class_sample_weight, backend.floatx())
1044      if sample_weight is not None:
1045        sample_weight = math_ops.cast(
1046            ops.convert_to_tensor_v2_with_dispatch(sample_weight),
1047            backend.floatx())
1048    else:
1049      y_classes = y
1050      if len(y.shape) == 2:
1051        if y.shape[1] > 1:
1052          y_classes = np.argmax(y, axis=1)
1053        elif y.shape[1] == 1:
1054          y_classes = np.reshape(y, y.shape[0])
1055
1056      class_sample_weight = np.asarray(
1057          [class_weight[cls] for cls in y_classes if cls in class_weight])
1058
1059      if len(class_sample_weight) != len(y_classes):
1060        # subtract the sets to pick all missing classes
1061        existing_classes = set(y_classes)
1062        existing_class_weight = set(class_weight.keys())
1063        raise ValueError(
1064            '`class_weight` must contain all classes in the data.'
1065            ' The classes %s exist in the data but not in '
1066            '`class_weight`.' % (existing_classes - existing_class_weight))
1067
1068  if class_sample_weight is not None and sample_weight is not None:
1069    # Multiply weights if both are provided.
1070    return class_sample_weight * sample_weight
1071  if sample_weight is not None:
1072    return sample_weight
1073  if class_sample_weight is not None:
1074    return class_sample_weight
1075  return None
1076
1077
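# A compact sketch of `standardize_weights` combining class and sample
# weights for a NumPy target. The labels and weight values are illustrative.
def _standardize_weights_example():
  y = np.array([[0], [1], [1], [0]])          # Binary labels, shape (4, 1).
  sample_weight = np.array([1., 1., 2., 2.])
  class_weight = {0: 1., 1: 3.}
  combined = standardize_weights(y, sample_weight=sample_weight,
                                 class_weight=class_weight)
  # Per-class weights [1, 3, 3, 1] multiplied elementwise by the sample
  # weights give [1., 3., 6., 2.].
  return combined

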
1078def has_symbolic_tensors(ls):
1079  if context.executing_eagerly():
1080    return False
1081  return has_tensors(ls)
1082
1083
1084def has_tensors(ls):
1085  """Returns true if `ls` contains tensors."""
1086  # Note: at some point in time ragged tensors didn't count as tensors, so this
1087  # returned false for ragged tensors. Making this return true fails some tests
1088  # which would then require a steps_per_epoch argument.
1089  if isinstance(ls, (list, tuple)):
1090    return any(
1091        tensor_util.is_tf_type(v) and
1092        not isinstance(v, ragged_tensor.RaggedTensor) for v in ls)
1093  if isinstance(ls, dict):
1094    return any(
1095        tensor_util.is_tf_type(v) and
1096        not isinstance(v, ragged_tensor.RaggedTensor)
1097        for _, v in ls.items())
1098  return tensor_util.is_tf_type(ls) and not isinstance(
1099      ls, ragged_tensor.RaggedTensor)
1100
1101
1102def get_metric_name(metric, weighted=False):
1103  """Returns the name corresponding to the given metric input.
1104
1105  Args:
1106    metric: Metric function name or reference.
1107    weighted: Boolean indicating if the given metric is weighted.
1108
1109  Returns:
1110      The metric name.
1111  """
1112  if tf2.enabled():
1113    # We keep the string that the user has set in compile as the metric name.
1114    if isinstance(metric, str):
1115      return metric
1116
1117    metric = metrics_module.get(metric)
1118    return metric.name if hasattr(metric, 'name') else metric.__name__
1119  else:
1120    metric_name_prefix = 'weighted_' if weighted else ''
1121    if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
1122      if metric in ('accuracy', 'acc'):
1123        suffix = 'acc'
1124      elif metric in ('crossentropy', 'ce'):
1125        suffix = 'ce'
1126    else:
1127      metric_fn = metrics_module.get(metric)
1128      # Get metric name as string
1129      if hasattr(metric_fn, 'name'):
1130        suffix = metric_fn.name
1131      else:
1132        suffix = metric_fn.__name__
1133    metric_name = metric_name_prefix + suffix
1134    return metric_name
1135
1136
1137def get_metric_function(metric, output_shape=None, loss_fn=None):
1138  """Returns the metric function corresponding to the given metric input.
1139
1140  Args:
1141      metric: Metric function name or reference.
1142      output_shape: The shape of the output that this metric will be calculated
1143        for.
1144      loss_fn: The loss function used.
1145
1146  Returns:
1147      The metric function.
1148  """
1149  if metric not in ['accuracy', 'acc', 'crossentropy', 'ce']:
1150    return metrics_module.get(metric)
1151
1152  is_sparse_categorical_crossentropy = (
1153      isinstance(loss_fn, losses.SparseCategoricalCrossentropy) or
1154      (isinstance(loss_fn, losses.LossFunctionWrapper) and
1155       loss_fn.fn == losses.sparse_categorical_crossentropy))
1156
1157  is_binary_crossentropy = (
1158      isinstance(loss_fn, losses.BinaryCrossentropy) or
1159      (isinstance(loss_fn, losses.LossFunctionWrapper) and
1160       loss_fn.fn == losses.binary_crossentropy))
1161
1162  if metric in ['accuracy', 'acc']:
1163    if output_shape[-1] == 1 or is_binary_crossentropy:
1164      return metrics_module.binary_accuracy
1165    elif is_sparse_categorical_crossentropy:
1166      return metrics_module.sparse_categorical_accuracy
1167    # If the output_shape[-1] is not 1, then we know output is `categorical`.
1168    # We assume it is sparse categorical only if loss is explicitly given
1169    # as sparse categorical crossentropy loss.
1170    return metrics_module.categorical_accuracy
1171  else:
1172    if output_shape[-1] == 1 or is_binary_crossentropy:
1173      return metrics_module.binary_crossentropy
1174    elif is_sparse_categorical_crossentropy:
1175      return metrics_module.sparse_categorical_crossentropy
1176    return metrics_module.categorical_crossentropy
1177
1178
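# A small sketch of how the 'accuracy' shortcut is resolved. The output
# shapes and loss instances below are illustrative only.
def _get_metric_function_example():
  # Scalar sigmoid-style output -> binary accuracy.
  fn_binary = get_metric_function(
      'accuracy', output_shape=(None, 1),
      loss_fn=losses.BinaryCrossentropy())
  # Multi-class output with an integer-label loss -> sparse categorical
  # accuracy.
  fn_sparse = get_metric_function(
      'accuracy', output_shape=(None, 10),
      loss_fn=losses.SparseCategoricalCrossentropy())
  return fn_binary, fn_sparse

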
1179def call_metric_function(metric_fn,
1180                         y_true,
1181                         y_pred=None,
1182                         weights=None,
1183                         mask=None):
1184  """Invokes metric function and returns the metric result tensor."""
1185  if mask is not None:
1186    mask = math_ops.cast(mask, y_pred.dtype)
1187    if weights is None:
1188      # Use mask as sample weight.
1189      weights = mask
1190    else:
1191      # Update dimensions of weights to match with mask.
1192      weights = math_ops.cast(weights, dtype=y_pred.dtype)
1193      mask, _, weights = losses_utils.squeeze_or_expand_dimensions(
1194          mask, sample_weight=weights)
1195      weights *= mask
1196
1197  if y_pred is not None:
1198    return metric_fn(y_true, y_pred, sample_weight=weights)
1199  # `Mean` metric only takes a single value.
1200  return metric_fn(y_true, sample_weight=weights)
1201
1202
1203def get_loss_function(loss):
1204  """Returns the loss corresponding to the loss input in `compile` API."""
1205  if loss is None or isinstance(loss, losses.Loss):
1206    return loss
1207
1208  if tf_inspect.isclass(loss) and issubclass(loss, losses.Loss):
1209    # It is not safe to assume that the loss takes no constructor arguments.
1210    raise ValueError(
1211        'Received uninstantiated Loss class: {}\nPlease call loss classes '
1212        'before passing them to Model.compile.'.format(loss))
1213
1214  # Deserialize loss configuration, if needed.
1215  if isinstance(loss, collections.abc.Mapping):
1216    loss = losses.get(loss)
1217
1218  # Custom callable class.
1219  if callable(loss) and not hasattr(loss, '__name__'):
1220    return loss
1221
1222  # Wrap loss function with signature `(y_true, y_pred, **kwargs)`
1223  # in `LossFunctionWrapper` class.
1224  loss_fn = losses.get(loss)
1225
1226  # For losses which are given as strings/functions in the compile API,
1227  # we always set the loss reduction type to be `SUM_OVER_BATCH_SIZE`
1228  # (both in distribution strategy context and otherwise).
1229  return losses.LossFunctionWrapper(
1230      loss_fn,
1231      name=loss_fn.__name__,
1232      reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE)
1233
1234
1235def validate_dataset_input(x, y, sample_weight, validation_split=None):
1236  """Validates user input arguments when a dataset iterator is passed.
1237
1238  Args:
1239    x: Input data. A `tf.data` dataset or iterator.
1240    y: Target data. It could be either Numpy array(s) or TensorFlow tensor(s).
1241      Expected to be `None` when `x` is a dataset iterator.
1242    sample_weight: An optional sample-weight array passed by the user to weight
1243      the importance of each sample in `x`. Expected to be `None` when `x` is a
1244      dataset iterator
1245    validation_split: Float between 0 and 1. Fraction of the training data to be
1246      used as validation data. Expected to be `None` when `x` is a dataset
1247      iterator.
1248
1249  Raises:
1250    ValueError: if argument `y` or `sample_weight` or `validation_split` are
1251        provided by user.
1252  """
1253  if y is not None:
1254    raise ValueError('You passed a dataset or dataset iterator (%s) as '
1255                     'input `x` to your model. In that case, you should '
1256                     'not specify a target (`y`) argument, since the dataset '
1257                     'or dataset iterator generates both input data and '
1258                     'target data. '
1259                     'Received: %s' % (x, y))
1260  if sample_weight is not None:
1261    raise ValueError('`sample_weight` argument is not supported when input '
1262                     '`x` is a dataset or a dataset iterator. Instead, you '
1263                     'can provide sample_weight as the third element of your '
1264                     'dataset, i.e. (inputs, targets, sample_weight). '
1265                     'Received: x=%s, sample_weight=%s' % (x, sample_weight))
1266  if validation_split is not None and validation_split != 0.0:
1267    raise ValueError(
1268        '`validation_split` argument is not supported when '
1269        'input `x` is a dataset or a dataset iterator. '
1270        'Received: x=%s, validation_split=%f' % (x, validation_split))
1271
1272
1273def validate_input_types(inp, orig_inp, allow_dict=True, field_name='inputs'):
1274  """Helper function to validate either inputs or targets."""
1275  if isinstance(inp, (list, tuple)):
1276    if not all(isinstance(v, np.ndarray) or
1277               tensor_util.is_tf_type(v) for v in inp):
1278      raise ValueError(
1279          'Please provide as model inputs either a single array or a list of '
1280          'arrays. You passed: {}={}'.format(field_name, str(orig_inp)))
1281  elif isinstance(inp, dict):
1282    if not allow_dict:
1283      raise ValueError(
1284          'You cannot pass a dictionary as model {}.'.format(field_name))
1285  elif not isinstance(inp, np.ndarray) and not tensor_util.is_tf_type(inp):
1286    raise ValueError(
1287        'Please provide as model inputs either a single array or a list of '
1288        'arrays. You passed: {}={}'.format(field_name, orig_inp))
1289
1290
1291def check_generator_arguments(y=None, sample_weight=None,
1292                              validation_split=None):
1293  """Validates arguments passed when using a generator."""
1294  if y is not None:
1295    raise ValueError('`y` argument is not supported when data is '
1296                     'a generator or Sequence instance. Instead pass targets'
1297                     ' as the second element of the generator.')
1298  if sample_weight is not None:
1299    raise ValueError('`sample_weight` argument is not supported when data is '
1300                     'a generator or Sequence instance. Instead pass sample'
1301                     ' weights as the third element of the generator.')
1302  if validation_split:
1303    raise ValueError('If your data is in the form of a Python generator, '
1304                     'you cannot use `validation_split`.')
1305
1306
1307def check_steps_argument(input_data, steps, steps_name):
1308  """Validates `steps` argument based on input data's type.
1309
1310  The cases when `steps` value must be provided are when
1311    1. input data passed is an iterator.
1312    2. model was built on top of symbolic tensors, input data is not
1313       required and is `None`.
1314    3. input data passed is a symbolic tensor.
1315
1316  Args:
1317      input_data: Input data. Can be Numpy array(s) or TensorFlow tensor(s) or
1318        tf.data.Dataset iterator or `None`.
1319      steps: Integer or `None`. Total number of steps (batches of samples) to
1320        execute.
1321      steps_name: The public API's parameter name for `steps`.
1322
1323  Returns:
1324    boolean, True if `steps` argument is required, else False.
1325
1326  Raises:
1327      ValueError: if `steps` argument is required for given input data type
1328        but not provided.
1329  """
1330  is_x_iterator = isinstance(
1331      input_data, (iterator_ops.Iterator, iterator_ops.IteratorBase))
1332  if (input_data is None or is_x_iterator or has_symbolic_tensors(input_data) or
1333      (isinstance(input_data, list) and not input_data)):
1334    if steps is None:
1335      input_type_str = 'a Dataset iterator' if is_x_iterator else 'data tensors'
1336      raise ValueError('When using {input_type} as input to a model, you should'
1337                       ' specify the `{steps_name}` argument.'.format(
1338                           input_type=input_type_str, steps_name=steps_name))
1339    return True
1340
1341  if isinstance(input_data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
1342    return True
1343
1344  if steps is not None:
1345    list_types = (np.ndarray, list, tuple)
1346    if (isinstance(input_data, list_types) or
1347        (isinstance(input_data, dict) and
1348         any(isinstance(v, list_types) for v in input_data.values()))):
1349      logging.warning('When passing input data as arrays, do not specify '
1350                      '`steps_per_epoch`/`steps` argument. '
1351                      'Please use `batch_size` instead.')
1352  return False
1353
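# Illustrative usage sketch added for documentation only; not part of the
# original module. `_demo_check_steps_argument` is a hypothetical helper
# showing that datasets rely on `steps` while in-memory arrays do not.
def _demo_check_steps_argument():
  dataset = dataset_ops.DatasetV2.from_tensor_slices(
      np.zeros((8, 3), dtype=np.float32)).batch(2)
  # Datasets are consumed step by step, so `steps` handling applies.
  assert check_steps_argument(dataset, steps=4, steps_name='steps_per_epoch')
  # Plain Numpy data is sliced by `batch_size`; `steps` is not required.
  assert not check_steps_argument(
      np.zeros((8, 3)), steps=None, steps_name='steps_per_epoch')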
1354
1355def cast_single_tensor(x, dtype=None):
1356  if isinstance(x, np.ndarray):
1357    x = ops.convert_to_tensor_v2_with_dispatch(x)
1358  dtype = dtype or backend.floatx()
1359  if x.dtype.is_floating:
1360    return math_ops.cast(x, dtype=dtype)
1361  return x
1362
1363
1364def cast_if_floating_dtype_and_mismatch(targets, outputs):
1365  """Returns target data tensors using correct datatype.
1366
  Checks that each target and output pair has the same datatype. If not, casts
1368  the target to the output's datatype.
1369
1370  Args:
1371    targets: tensor or list of targets.
1372    outputs: tensor or list of outputs.
1373
1374  Returns:
1375    Targets in appropriate datatype.
1376  """
1377  if tensor_util.is_tf_type(targets):
1378    # There is one target, so output[0] should be the only output.
1379    return cast_single_tensor(targets, dtype=outputs[0].dtype)
1380  new_targets = []
1381  for target, out in zip(targets, outputs):
1382    if isinstance(target, np.ndarray):
1383      target = ops.convert_to_tensor_v2_with_dispatch(target)
1384    if target.dtype != out.dtype:
1385      new_targets.append(cast_single_tensor(target, dtype=out.dtype))
1386    else:
1387      new_targets.append(target)
1388  return new_targets
1389
1390
1391def cast_if_floating_dtype(x, dtype=None):
1392  """Casts the given data tensors to the default floating point type.
1393
1394  Casts only if the input is already a floating point type.
1395  Args:
1396    x: tensor or list/tuple of tensors.
1397    dtype: The dtype to which Tensors should be cast.
1398
1399  Returns:
1400    Converted input.
1401  """
1402  return nest.map_structure(functools.partial(cast_single_tensor, dtype=dtype),
1403                            x)
1404
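# Illustrative usage sketch added for documentation only; not part of the
# original module. `_demo_cast_if_floating_dtype` is a hypothetical helper
# showing that floating-point data is cast to `backend.floatx()` while
# integer data is left untouched.
def _demo_cast_if_floating_dtype():
  floats = np.zeros((2, 2), dtype=np.float64)
  ints = np.zeros((2, 2), dtype=np.int64)
  assert cast_if_floating_dtype(floats).dtype == backend.floatx()
  assert cast_if_floating_dtype(ints).dtype == dtypes.int64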
1405
1406def cast_to_model_input_dtypes(x, model):
1407  """Casts the given data tensors to the dtypes of the model inputs.
1408
1409  Args:
1410    x: tensor or list/tuple of tensors.
1411    model: The model.
1412
1413  Returns:
    Converted input. Each tensor is cast to the dtype of the corresponding
    input in `model.inputs`.
1416  """
1417  input_dtypes = nest.map_structure(lambda t: t.dtype, model.inputs)
1418  return nest.map_structure(math_ops.cast, x, input_dtypes)
1419
1420
1421def prepare_sample_weight_modes(training_endpoints, sample_weight_mode):
1422  """Prepares sample weight modes for the model.
1423
1424  Args:
1425    training_endpoints: List of model _TrainingEndpoints.
1426    sample_weight_mode: sample weight mode user input passed from compile API.
1427
1428  Raises:
1429    ValueError: In case of invalid `sample_weight_mode` input.
1430  """
1431
1432  if isinstance(sample_weight_mode, collections.abc.Mapping):
1433    generic_utils.check_for_unexpected_keys(
1434        'sample_weight_mode', sample_weight_mode,
1435        [e.output_name for e in training_endpoints])
1436
1437    for end_point in training_endpoints:
1438      if not end_point.should_skip_target_weights():
1439        if end_point.output_name not in sample_weight_mode:
1440          raise ValueError('Output ' + end_point.output_name +
                           ' missing from `_sample_weight_modes` dictionary')
1442        else:
1443          end_point.sample_weight_mode = sample_weight_mode.get(
1444              end_point.output_name)
1445  elif isinstance(sample_weight_mode, (list, tuple)):
1446    if len(sample_weight_mode) != len(training_endpoints):
1447      raise ValueError('When passing a list as sample_weight_mode, '
1448                       'it should have one entry per model output. '
1449                       'The model has ' + str(len(training_endpoints)) +
1450                       ' outputs, but you passed ' +
                       str(len(sample_weight_mode)) + ' sample_weight_modes.')
1452    for mode, endpoint in zip(sample_weight_mode, training_endpoints):
1453      if not endpoint.should_skip_target_weights():
1454        endpoint.sample_weight_mode = mode
1455  else:
1456    for endpoint in training_endpoints:
1457      if not endpoint.should_skip_target_weights():
1458        endpoint.sample_weight_mode = sample_weight_mode
1459
1460
1461def prepare_loss_functions(loss, output_names):
1462  """Converts loss to a list of loss functions.
1463
1464  Args:
1465      loss: String (name of objective function), objective function or
1466        `tf.losses.Loss` instance. See `tf.losses`. If the model has multiple
1467        outputs, you can use a different loss on each output by passing a
1468        dictionary or a list of losses. The loss value that will be minimized by
1469        the model will then be the sum of all individual losses.
1470      output_names: List of model output names.
1471
1472  Returns:
1473      A list of loss objective functions.
1474
1475  Raises:
      ValueError: If loss is a dict with keys not in model output names,
          or if loss is a list whose length does not match the number of
          model outputs.
1478  """
1479  if isinstance(loss, collections.abc.Mapping):
1480    generic_utils.check_for_unexpected_keys('loss', loss, output_names)
1481    loss_functions = []
1482    for name in output_names:
1483      if name not in loss:
1484        logging.warning(
1485            'Output {0} missing from loss dictionary. We assume '
1486            'this was done on purpose. The fit and evaluate APIs will not be '
1487            'expecting any data to be passed to {0}.'.format(name))
1488      loss_functions.append(get_loss_function(loss.get(name, None)))
1489  elif isinstance(loss, str):
1490    loss_functions = [get_loss_function(loss) for _ in output_names]
1491  elif isinstance(loss, collections.abc.Sequence):
1492    if len(loss) != len(output_names):
1493      raise ValueError('When passing a list as loss, it should have one entry '
                       'per model output. The model has {} outputs, but you '
1495                       'passed loss={}'.format(len(output_names), loss))
1496    loss_functions = nest.map_structure(get_loss_function, loss)
1497  else:
1498    loss_functions = [get_loss_function(loss) for _ in range(len(output_names))]
1499
1500  return loss_functions
1501
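# Illustrative usage sketch added for documentation only; not part of the
# original module. `_demo_prepare_loss_functions` is a hypothetical helper:
# a single loss name is broadcast to every output, while a dict maps output
# names to losses (outputs missing from the dict get `None` plus a warning).
def _demo_prepare_loss_functions():
  broadcast = prepare_loss_functions('mse', ['main', 'aux'])
  assert len(broadcast) == 2
  from_dict = prepare_loss_functions({'main': 'mae'}, ['main', 'aux'])
  assert from_dict[1] is None  # 'aux' was omitted from the loss dict.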
1502
1503def prepare_loss_weights(training_endpoints, loss_weights=None):
1504  """Converts loss weights to a list of loss weights.
1505
1506  The result loss weights will be populated on the training endpoint.
1507
1508  Args:
1509      training_endpoints: List of model training endpoints.
1510      loss_weights: Optional list or dictionary specifying scalar coefficients
1511        (Python floats) to weight the loss contributions of different model
1512        outputs. The loss value that will be minimized by the model will then be
1513        the *weighted sum* of all individual losses, weighted by the
        `loss_weights` coefficients. If a list, it is expected to have a 1:1
        mapping to the model's outputs. If a dict, it is expected to map
        output names (strings) to scalar coefficients.
1517
1518  Raises:
      ValueError: If loss_weights is a dict with a key not in model output
          names, or a list whose length does not match the number of model
          outputs.
1521  """
1522  if loss_weights is None:
1523    for e in training_endpoints:
1524      e.loss_weight = 1.
1525  elif isinstance(loss_weights, collections.abc.Mapping):
1526    generic_utils.check_for_unexpected_keys(
1527        'loss_weights', loss_weights,
1528        [e.output_name for e in training_endpoints])
1529    for e in training_endpoints:
1530      e.loss_weight = loss_weights.get(e.output_name, 1.)
1531  elif isinstance(loss_weights, list):
1532    if len(loss_weights) != len(training_endpoints):
1533      raise ValueError('When passing a list as loss_weights, '
1534                       'it should have one entry per model output. '
1535                       'The model has ' + str(len(training_endpoints)) +
1536                       ' outputs, but you passed loss_weights=' +
1537                       str(loss_weights))
1538    for w, e in zip(loss_weights, training_endpoints):
1539      e.loss_weight = w
1540  else:
1541    raise TypeError('Could not interpret loss_weights argument: ' +
                    str(loss_weights) + ' - expected a list or dict.')
1543
1544
# TODO(rohanj): This is a hack to avoid depending on feature_column and
# creating a cyclical dependency. Figure out a cleaner solution.
1547def is_feature_layer(layer):
1548  """Returns whether `layer` is a FeatureLayer or not."""
1549  return getattr(layer, '_is_feature_layer', False)
1550
1551
1552def is_eager_dataset_or_iterator(data):
1553  return context.executing_eagerly() and isinstance(
1554      data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2,
1555             iterator_ops.IteratorBase))
1556
1557
1558# pylint: disable=protected-access
1559def get_dataset_graph_def(dataset):
1560  if context.executing_eagerly():
1561    graph_def_str = dataset._as_serialized_graph().numpy()
1562  else:
1563    graph_def_str = backend.get_value(dataset._as_serialized_graph())
1564  return graph_pb2.GraphDef().FromString(graph_def_str)
1565
1566
1567def verify_dataset_shuffled(x):
1568  """Verifies that the dataset is shuffled.
1569
1570  Args:
1571    x: Dataset passed as an input to the model.
1572
1573  Returns:
1574    boolean, whether the input dataset is shuffled or not.
1575  """
1576  assert isinstance(x, dataset_ops.DatasetV2)
1577  graph_def = get_dataset_graph_def(x)
1578  for node in graph_def.node:
1579    if node.op.startswith('ShuffleDataset'):
1580      return True
1581  # Also check graph_def.library.function for ds.interleave or ds.flat_map
1582  for function in graph_def.library.function:
1583    for node in function.node_def:
1584      if node.op.startswith('ShuffleDataset'):
1585        return True
1586  logging.warning('Expected a shuffled dataset but input dataset `x` is '
1587                  'not shuffled. Please invoke `shuffle()` on input dataset.')
1588  return False
1589
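# Illustrative usage sketch added for documentation only; not part of the
# original module. `_demo_verify_dataset_shuffled` is a hypothetical helper
# showing that shuffling is detected by scanning the serialized dataset graph
# for `ShuffleDataset` ops.
def _demo_verify_dataset_shuffled():
  dataset = dataset_ops.DatasetV2.range(10)
  assert not verify_dataset_shuffled(dataset)  # Logs a warning.
  assert verify_dataset_shuffled(dataset.shuffle(10))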
1590
1591def is_dataset_or_iterator(data):
1592  return isinstance(data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2,
1593                           iterator_ops.Iterator, iterator_ops.IteratorBase))
1594
1595
1596def get_iterator(dataset):
1597  """Create and initialize an iterator from a dataset."""
1598  if context.executing_eagerly():
1599    iterator = dataset_ops.make_one_shot_iterator(dataset)
1600  else:
1601    iterator = dataset_ops.make_initializable_iterator(dataset)
1602  initialize_iterator(iterator)
1603  return iterator
1604
1605
1606def initialize_iterator(iterator):
1607  if not context.executing_eagerly():
1608    init_op = iterator.initializer
1609    backend.get_session((init_op,)).run(init_op)
1610
1611
1612def extract_tensors_from_dataset(dataset):
1613  """Extract a tuple of tensors `inputs, targets, sample_weight` from a dataset.
1614
1615  Args:
1616    dataset: Dataset instance.
1617
1618  Returns:
    Tuple of tensors `x, y, weights`. `y` and `weights` entries may be None.
1620  """
1621  iterator = get_iterator(dataset)
1622  inputs, targets, sample_weight = unpack_iterator_input(iterator)
1623  return inputs, targets, sample_weight
1624
1625
1626def unpack_iterator_input(iterator):
1627  """Convert a dataset iterator to a tuple of tensors `x, y, sample_weights`.
1628
1629  Args:
1630    iterator: Instance of a dataset iterator.
1631
1632  Returns:
    Tuple of tensors `x, y, weights`. `y` and `weights` entries may be None.
1634  """
1635  try:
1636    next_element = iterator.get_next()
1637  except errors.OutOfRangeError:
1638    raise RuntimeError('Your dataset iterator ran out of data; '
                       'make sure that your dataset can generate the '
                       'required number of samples.')
1641
1642  if isinstance(next_element, (list, tuple)):
1643    if len(next_element) not in [2, 3]:
1644      raise ValueError(
1645          'Please provide model inputs as a list or tuple of 2 or 3 '
          'elements: (input, target) or (input, target, sample_weights). '
1647          'Received %s' % next_element)
1648    if len(next_element) == 2:
1649      x, y = next_element
1650      weights = None
1651    else:
1652      x, y, weights = next_element
1653  else:
1654    x = next_element
1655    y = None
1656    weights = None
1657  return x, y, weights
1658
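# Illustrative usage sketch added for documentation only; not part of the
# original module. `_demo_extract_tensors_from_dataset` is a hypothetical
# helper: a dataset yielding `(inputs, targets)` tuples unpacks into `x, y`
# with `weights` left as None.
def _demo_extract_tensors_from_dataset():
  features = np.zeros((4, 3), dtype=np.float32)
  labels = np.ones((4, 1), dtype=np.float32)
  dataset = dataset_ops.DatasetV2.from_tensor_slices(
      (features, labels)).batch(2)
  _, _, weights = extract_tensors_from_dataset(dataset)
  assert weights is None  # The dataset has no third (sample_weight) element.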
1659
1660def infer_steps_for_dataset(model,
1661                            dataset,
1662                            steps,
1663                            epochs=1,
1664                            steps_name='steps'):
1665  """Infers steps_per_epoch needed to loop through a dataset.
1666
1667  Args:
1668      model: Keras model instance.
1669      dataset: Input data of type tf.data.Dataset.
1670      steps: Number of steps to draw from the dataset (may be None if unknown).
1671      epochs: Number of times to iterate over the dataset.
1672      steps_name: The string name of the steps argument, either `steps`,
1673        `validation_steps`, or `steps_per_epoch`. Only used for error message
1674        formatting.
1675
1676  Returns:
1677    Integer or `None`. Inferred number of steps to loop through the dataset.
1678    `None` is returned if 1) the size of the dataset is unknown and `steps` was
1679    not specified, or 2) this is multi-worker training and auto sharding is
1680    enabled.
1681
1682  Raises:
1683    ValueError: In case of invalid argument values.
1684  """
1685  assert isinstance(dataset, dataset_ops.DatasetV2)
1686  if (model._in_multi_worker_mode() and
1687      (dataset.options().experimental_distribute.auto_shard_policy !=
1688       options_lib.AutoShardPolicy.OFF)):
1689    # If the dataset would be auto-sharded, we should not infer a local
    # steps_per_epoch due to the possible imbalanced sharding between workers.
1691    return None
1692
1693  size = backend.get_value(cardinality.cardinality(dataset))
1694  if size == cardinality.INFINITE and steps is None:
1695    raise ValueError('When passing an infinitely repeating dataset, you '
1696                     'must specify the `%s` argument.' % (steps_name,))
1697  if size >= 0:
1698    if steps is not None and steps * epochs > size:
1699      if epochs > 1:
1700        raise ValueError('The dataset you passed contains %s batches, but you '
1701                         'passed `epochs=%s` and `%s=%s`, which is a total of '
1702                         '%s steps. We cannot draw that many steps from this '
                         'dataset. We suggest setting `%s=%s`.' %
1704                         (size, epochs, steps_name, steps, steps * epochs,
1705                          steps_name, size // epochs))
1706      else:
1707        raise ValueError('The dataset you passed contains %s batches, but you '
1708                         'passed `%s=%s`. We cannot draw that many steps from '
                         'this dataset. We suggest setting `%s=%s`.' %
1710                         (size, steps_name, steps, steps_name, size))
1711  if steps is None:
1712    if size >= 0:
1713      return size
1714    return None
1715  return steps
1716
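# Illustrative usage sketch added for documentation only; not part of the
# original module. `_StubModel` and `_demo_infer_steps_for_dataset` are
# hypothetical: for a finite dataset of known cardinality, the batch count is
# returned when `steps` is omitted.
def _demo_infer_steps_for_dataset():

  class _StubModel(object):
    """Minimal stand-in exposing the only attribute the helper touches."""

    def _in_multi_worker_mode(self):
      return False

  dataset = dataset_ops.DatasetV2.from_tensor_slices(
      np.zeros((10, 1), dtype=np.float32)).batch(2)
  assert infer_steps_for_dataset(_StubModel(), dataset, steps=None) == 5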
1717
1718class ModelInputs(object):
1719  """Encapsulates model inputs.
1720
1721  Allows for transforming model inputs while keeping the same structure.
1722  """
1723
1724  def __init__(self, inputs):
1725    self._inputs = inputs
1726    self._is_dict = isinstance(self._inputs, dict)
1727    self._is_single_input = not isinstance(self._inputs, (list, tuple, dict))
1728
1729    self._flattened_inputs = []
1730    self._input_names = []
1731
1732    if self._is_dict:
1733      for k in sorted(self._inputs.keys()):
1734        self._flattened_inputs.append(self._inputs[k])
1735        self._input_names.append(k)
1736    else:
1737      self._flattened_inputs = nest.flatten(self._inputs)
1738      self._input_names = [
1739          'input_%d' % (i + 1) for i in range(len(self._flattened_inputs))
1740      ]
1741
1742  def get_input_names(self):
1743    """Returns keys to name inputs by.
1744
    If the inputs provided were a list, tuple or single entry, we make up keys
    of the form 'input_%d'. For the dictionary case, we return a sorted list
    of keys.
1747    """
1748    return self._input_names
1749
1750  def get_symbolic_inputs(self, return_single_as_list=False):
1751    """Returns inputs to be set as self.inputs for a model."""
1752    # TODO(karmel): There is a side-effect here where what you get
1753    # with as_list and as_dict depends on whether you have called this
1754    # method first, since it modifies in place.
1755    for i, (k, v) in enumerate(zip(self._input_names, self._flattened_inputs)):
1756      if isinstance(v, (list, float, int)):
1757        v = np.asarray(v)
1758        if v.ndim == 1:
1759          v = np.expand_dims(v, 1)
1760
1761      if isinstance(v, np.ndarray):
1762        # We fix the placeholder shape except the batch size.
1763        # This is suboptimal, but it is the best we can do with the info
1764        # we have. The user should call `model._set_inputs(placeholders)`
1765        # to specify custom placeholders if the need arises.
1766        shape = (None,) + tuple(v.shape[1:])
1767        if shape == (None,):
1768          shape = (None, 1)
1769        dtype = dtypes.as_dtype(v.dtype)
1770        if dtype.is_floating:
1771          dtype = backend.floatx()
1772        v = backend.placeholder(shape=shape, name=k, dtype=dtype)
1773      elif isinstance(v, tensor_spec.TensorSpec):
1774        shape = (None,) + tuple(v.shape.as_list()[1:])
1775        if shape == (None,):
1776          shape = (None, 1)
1777        v = backend.placeholder(shape=shape, name=k, dtype=v.dtype)
1778
1779      self._flattened_inputs[i] = v
1780
1781    if self._is_dict:
1782      return dict(zip(self._input_names, self._flattened_inputs))
1783    if self._is_single_input and not return_single_as_list:
1784      return self._flattened_inputs[0]
1785    return self._flattened_inputs
1786
1787  def as_dict(self):
1788    """An iterable over a dictionary version of inputs."""
1789    for k, v in zip(self._input_names, self._flattened_inputs):
1790      yield k, v
1791
1792  def as_list(self):
1793    """Returning the inputs as a list."""
1794    return self._flattened_inputs
1795
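# Illustrative usage sketch added for documentation only; not part of the
# original module. `_demo_model_inputs` is a hypothetical helper: a single
# array is named 'input_1', while dict inputs keep their (sorted) keys.
def _demo_model_inputs():
  single = ModelInputs(np.zeros((8, 4), dtype=np.float32))
  assert single.get_input_names() == ['input_1']
  named = ModelInputs({'ids': np.zeros((8, 1)), 'age': np.zeros((8, 1))})
  assert named.get_input_names() == ['age', 'ids']  # Sorted dictionary keys.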
1796
1802
1803
1804def generic_output_names(outputs_list):
1805  return ['output_%d' % (i + 1) for i in range(len(outputs_list))]
1806
1807
1808def should_run_validation(validation_freq, epoch):
1809  """Checks if validation should be run this epoch.
1810
1811  Args:
1812    validation_freq: Integer or list. If an integer, specifies how many training
1813      epochs to run before a new validation run is performed. If a list,
1814      specifies the epochs on which to run validation.
1815    epoch: Integer, the number of the training epoch just completed.
1816
1817  Returns:
1818    Bool, True if validation should be run.
1819
1820  Raises:
    ValueError: if `validation_freq` is an integer less than 1, or if
    it is neither an integer nor a `collections.abc.Container`.
1823  """
1824  # `epoch` is 0-indexed internally but 1-indexed in the public API.
1825  one_indexed_epoch = epoch + 1
1826
1827  if isinstance(validation_freq, int):
1828    if validation_freq < 1:
1829      raise ValueError('`validation_freq` can not be less than 1.')
1830    return one_indexed_epoch % validation_freq == 0
1831
1832  if not isinstance(validation_freq, collections.abc.Container):
1833    raise ValueError('`validation_freq` must be an Integer or '
1834                     '`collections.abc.Container` (e.g. list, tuple, etc.)')
1835  return one_indexed_epoch in validation_freq
1836
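# Illustrative usage sketch added for documentation only; not part of the
# original module. `_demo_should_run_validation` is a hypothetical helper:
# `epoch` is 0-indexed internally, so `validation_freq=2` runs validation
# after the 2nd, 4th, ... public epochs.
def _demo_should_run_validation():
  assert not should_run_validation(2, epoch=0)   # Public epoch 1: skip.
  assert should_run_validation(2, epoch=1)       # Public epoch 2: run.
  assert should_run_validation([1, 5], epoch=4)  # Public epoch 5 is listed.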
1837
1838def split_training_and_validation_data(x, y, sample_weights, validation_split):
1839  """Split input data into train/eval section based on validation_split."""
1840  if has_symbolic_tensors(x):
1841    raise ValueError('If your data is in the form of symbolic tensors, '
1842                     'you cannot use `validation_split`.')
1843  if hasattr(x[0], 'shape'):
1844    split_at = int(x[0].shape[0] * (1. - validation_split))
1845  else:
1846    split_at = int(len(x[0]) * (1. - validation_split))
1847  x, val_x = (generic_utils.slice_arrays(x, 0, split_at),
1848              generic_utils.slice_arrays(x, split_at))
1849  y, val_y = (generic_utils.slice_arrays(y, 0, split_at),
1850              generic_utils.slice_arrays(y, split_at))
1851  if sample_weights:
1852    sample_weights, val_sample_weights = (
1853        generic_utils.slice_arrays(sample_weights, 0, split_at),
1854        generic_utils.slice_arrays(sample_weights, split_at),
1855    )
1856  else:
1857    val_sample_weights = None
1858  return x, y, sample_weights, val_x, val_y, val_sample_weights
1859
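# Illustrative usage sketch added for documentation only; not part of the
# original module. `_demo_split_training_and_validation_data` and its toy data
# are hypothetical: with `validation_split=0.2`, the last 20% of samples
# become the validation set.
def _demo_split_training_and_validation_data():
  x = [np.arange(10).reshape(10, 1)]
  y = [np.arange(10).reshape(10, 1)]
  x_train, _, _, x_val, _, _ = split_training_and_validation_data(
      x, y, sample_weights=None, validation_split=0.2)
  assert len(x_train[0]) == 8 and len(x_val[0]) == 2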
1860
1861def unpack_validation_data(validation_data, raise_if_ambiguous=True):
1862  """Unpack validation data based input type.
1863
1864  The validation data is not touched if its dataset or dataset iterator.
1865  For other type of input (Numpy or tensor), it will be unpacked into tuple of
1866  3 which is x, y and sample weights.
1867
1868  Args:
    validation_data: dataset, dataset iterator, or a tuple of Numpy arrays or
      tensors.
1870    raise_if_ambiguous: boolean on whether to fail if validation_data cannot be
1871      parsed. Otherwise simply return validation_data, None, None and defer the
1872      decision to the caller.
1873
1874  Returns:
1875    tuple of 3, (x, y, sample_weights) for numpy and tensor input.
1876  """
1877  if (isinstance(validation_data, (iterator_ops.Iterator,
1878                                   iterator_ops.IteratorBase,
1879                                   dataset_ops.DatasetV2,
1880                                   data_utils.Sequence))
1881      or not hasattr(validation_data, '__len__')):
1882    val_x = validation_data
1883    val_y = None
1884    val_sample_weight = None
1885  elif len(validation_data) == 2:
1886    try:
1887      val_x, val_y = validation_data  # pylint: disable=unpacking-non-sequence
1888      val_sample_weight = None
1889    except ValueError:
1890      val_x, val_y, val_sample_weight = validation_data, None, None
1891  elif len(validation_data) == 3:
1892    try:
1893      val_x, val_y, val_sample_weight = validation_data  # pylint: disable=unpacking-non-sequence
1894    except ValueError:
1895      val_x, val_y, val_sample_weight = validation_data, None, None
1896  else:
1897    if raise_if_ambiguous:
1898      raise ValueError(
1899          'When passing a `validation_data` argument, '
1900          'it must contain either 2 items (x_val, y_val), '
1901          'or 3 items (x_val, y_val, val_sample_weights), '
1902          'or alternatively it could be a dataset or a '
          'dataset iterator. '
1904          'However we received `validation_data=%s`' % validation_data)
1905    val_x, val_y, val_sample_weight = validation_data, None, None
1906  return val_x, val_y, val_sample_weight
1907
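# Illustrative usage sketch added for documentation only; not part of the
# original module. `_demo_unpack_validation_data` is a hypothetical helper:
# tuples of length 2 or 3 are unpacked, while datasets pass through untouched.
def _demo_unpack_validation_data():
  x_val, y_val = np.zeros((4, 2)), np.ones((4,))
  unpacked_x, unpacked_y, unpacked_w = unpack_validation_data((x_val, y_val))
  assert unpacked_x is x_val and unpacked_y is y_val and unpacked_w is None
  dataset = dataset_ops.DatasetV2.from_tensor_slices((x_val, y_val))
  ds_x, ds_y, ds_w = unpack_validation_data(dataset)
  assert ds_x is dataset and ds_y is None and ds_w is None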
1908
1909class TrainingLoop(object):
1910  """TrainingLoop is a wrapper class around the training logic.
1911
  This class encapsulates the fit/eval/predict logic for different kinds of
  data input and model configuration.

  Note that TrainingLoop is stateless: it doesn't contain any internal fields
  and can be reused with different models and inputs.
1917  """
1918
1919  def fit(self,
1920          model,
1921          x=None,
1922          y=None,
1923          batch_size=None,
1924          epochs=1,
1925          verbose=1,
1926          callbacks=None,
1927          validation_split=0.,
1928          validation_data=None,
1929          shuffle=True,
1930          class_weight=None,
1931          sample_weight=None,
1932          initial_epoch=0,
1933          steps_per_epoch=None,
1934          validation_steps=None,
1935          validation_freq=1,
1936          **kwargs):
1937    """Train the model with the inputs and targets."""
1938    raise NotImplementedError()
1939
1940  def evaluate(self,
1941               model,
1942               x=None,
1943               y=None,
1944               batch_size=None,
1945               verbose=1,
1946               sample_weight=None,
1947               steps=None,
1948               callbacks=None,
1949               **kwargs):
1950    """Returns the loss value & metrics values for the model in test mode."""
1951    raise NotImplementedError()
1952
1953  def predict(self,
1954              model,
1955              x,
1956              batch_size=None,
1957              verbose=0,
1958              steps=None,
1959              callbacks=None,
1960              **kwargs):
1961    raise NotImplementedError()
1962