xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/engine/training_arrays_v1.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Part of the Keras training engine related to plain array data."""
16# pylint: disable=protected-access
17
18import functools
19
20import numpy as np
21
22from tensorflow.python.data.ops import dataset_ops
23from tensorflow.python.data.ops import iterator_ops
24from tensorflow.python.eager import context
25from tensorflow.python.framework import errors
26from tensorflow.python.keras import backend
27from tensorflow.python.keras import callbacks as cbks
28from tensorflow.python.keras.distribute import distributed_training_utils_v1
29from tensorflow.python.keras.engine import training_utils_v1
30from tensorflow.python.keras.utils.generic_utils import make_batches
31from tensorflow.python.keras.utils.generic_utils import slice_arrays
32from tensorflow.python.keras.utils.mode_keys import ModeKeys
33from tensorflow.python.platform import tf_logging as logging
34from tensorflow.python.util import nest
35
36try:
37  from scipy.sparse import issparse  # pylint: disable=g-import-not-at-top
38except ImportError:
39  issparse = None
40
41
def model_iteration(model,
                    inputs,
                    targets=None,
                    sample_weights=None,
                    batch_size=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    val_inputs=None,
                    val_targets=None,
                    val_sample_weights=None,
                    shuffle=True,
                    initial_epoch=0,
                    steps_per_epoch=None,
                    validation_steps=None,
                    validation_freq=1,
                    mode=ModeKeys.TRAIN,
                    validation_in_fit=False,
                    prepared_feed_values_from_dataset=False,
                    steps_name='steps',
                    **kwargs):
  """Loop function for arrays of data with modes TRAIN/TEST/PREDICT.

  Args:
      model: Keras Model instance.
      inputs: Either a list or dictionary of arrays, or a dataset instance.
      targets: List/dictionary of target arrays.
      sample_weights: Optional list of sample weight arrays.
      batch_size: Integer batch size or None if unknown.
      epochs: Number of times to iterate over the data.
      verbose: 0, 1, or 2. Verbosity mode.
        0 = silent, 1 = progress bar, 2 = one line per epoch.
        Note that the progress bar is not particularly useful when
        logged to a file, so verbose=2 is recommended when not running
        interactively (eg, in a production environment).
      callbacks: List of callbacks to be called during training.
      val_inputs: Either a list or dictionary of arrays, or a dataset instance.
      val_targets: List/dictionary of target arrays.
      val_sample_weights: Optional list of sample weight arrays.
      shuffle: Whether to shuffle the data at the beginning of each epoch.
        Only used by the sample-wise loop, i.e. when neither a dataset nor
        `steps_per_epoch` is given. The string `'batch'` shuffles in
        batch-sized chunks (via `training_utils_v1.batch_shuffle`) and is the
        recommended value for HDF5 input data; any other truthy value
        shuffles individual samples.
      initial_epoch: Epoch at which to start training (useful for resuming a
        previous training run).
      steps_per_epoch: Total number of steps (batches of samples) before
        declaring one epoch finished and starting the next epoch. Ignored with
        the default value of `None`.
      validation_steps: Number of steps to run validation for (only if doing
        validation from data tensors). Ignored with the default value of
        `None`.
      validation_freq: Only relevant if validation data is provided. Integer or
        `collections.abc.Container` instance (e.g. list, tuple, etc.). If an
        integer, specifies how many training epochs to run before a new
        validation run is performed, e.g. `validation_freq=2` runs
        validation every 2 epochs. If a Container, specifies the epochs on
        which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
        validation at the end of the 1st, 2nd, and 10th epochs.
      mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
      validation_in_fit: if true, then this method is invoked from within
        training iteration (for validation). In the case where `val_inputs` is
        a dataset, this flag indicates that its iterator and feed values are
        already created so should properly reuse resources.
      prepared_feed_values_from_dataset: if True, `inputs` is a list of feed
        tensors returned from `_prepare_feed_values` call on the validation
        dataset, so do not call it again on `inputs`. Should only be used for
        inline validation (i.e., only if `validation_in_fit` is also True).
      steps_name: The string name of the steps argument, either `steps`,
        `validation_steps`, or `steps_per_epoch`. Only used for error message
        formatting.
      **kwargs: Additional arguments for backwards compatibility. Only
        `steps` (an alias for `steps_per_epoch`) is accepted.

  Returns:
      - In TRAIN mode: `History` object.
      - In TEST mode: Evaluation metrics (unwrapped to a single value when
        there is exactly one result).
      - In PREDICT mode: Outputs of the Model called on inputs (unwrapped to a
        single value when there is exactly one output).

  Raises:
      TypeError: if unknown keyword arguments are passed, or if slicing the
        input arrays into batches fails (e.g. HDF5 data without
        `shuffle='batch'`).
      ValueError: in case of invalid arguments.
  """
  # Backwards compatibility.
  if 'steps' in kwargs:
    steps_per_epoch = kwargs.pop('steps')
  if kwargs:
    raise TypeError('Unknown arguments: %s' % (kwargs,))

  # In case we were passed a dataset, we extract symbolic tensors from it.
  reset_dataset_after_each_epoch = False
  input_iterator = None
  is_dataset = isinstance(inputs,
                          (dataset_ops.DatasetV1, dataset_ops.DatasetV2))
  # TODO(fchollet): consider moving `steps_per_epoch` inference to
  # _standardize_user_data and set reset_dataset_after_each_epoch as an
  # attribute on the dataset instance.
  if is_dataset:
    if steps_per_epoch is None:
      reset_dataset_after_each_epoch = True
      steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
          model, inputs, steps_per_epoch, epochs=epochs, steps_name=steps_name)
    input_iterator = _get_iterator(inputs, model._distribution_strategy)

  # Enter tf.distribute.Strategy scope. It is exited in the `finally` clause
  # below so that it is also unwound when the loop raises (failing callback,
  # KeyboardInterrupt, batch slicing error, ...); previously an exception left
  # the strategy/learning-phase scope active in the caller.
  scope = None
  if model._distribution_strategy:
    scope = distributed_training_utils_v1.distributed_scope(
        strategy=model._distribution_strategy,
        learning_phase=(1 if mode == ModeKeys.TRAIN else 0))
    scope.__enter__()

  try:
    use_steps = is_dataset or steps_per_epoch is not None
    do_validation = val_inputs is not None

    # Prepare input data.
    inputs = input_iterator or inputs
    if validation_in_fit and prepared_feed_values_from_dataset:
      # When invoking validation in training loop, avoid creating iterator and
      # list of feed values for the same validation dataset multiple times
      # (which essentially would call `iterator.get_next()` that slows down
      # execution and leads to OOM errors eventually).
      ins = inputs
    else:
      ins = _prepare_feed_values(model, inputs, targets, sample_weights, mode)
      # `ins` is a function when a distribute strategy is used in Eager mode.
      # In that case `is_dataset` is True.  The code branches that have
      # requirements about the type of `ins` do not trigger in the distributed
      # case.

    if not is_dataset:
      num_samples_or_steps = _get_num_samples_or_steps(ins, batch_size,
                                                       steps_per_epoch)
    else:
      num_samples_or_steps = steps_per_epoch

    # Update sample_weight_mode of the model if sample_weights is specified by
    # the user. We need to call this function after we have a handle on the
    # inputs (both numpy arrays and datasets) in order to determine if the
    # user has specified sample_weights.
    _update_sample_weight_mode(model, mode, ins)

    # Get step function and loop type. As part of building the execution
    # function we recompile the metrics based on the updated
    # sample_weight_mode value.
    f = _make_execution_function(model, mode)

    # Prepare validation data. Hold references to the iterator and the input
    # list to properly reinitialize and reuse in multiple validation passes.
    val_iterator = None
    if isinstance(val_inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
      if validation_steps is None:
        # Because we pass an iterator feed instead of a Dataset to the eval
        # model_iteration() call, it will not trigger the dataset-input path
        # that determines the number of steps required. To avoid this issue,
        # set validation_steps here if validation_steps is None.
        validation_steps = training_utils_v1.infer_steps_for_dataset(
            model,
            val_inputs,
            validation_steps,
            epochs=epochs,
            steps_name='validation_steps')
      val_iterator = _get_iterator(val_inputs, model._distribution_strategy)
      val_inputs = _prepare_feed_values(
          model, val_iterator, val_targets, val_sample_weights, ModeKeys.TEST)
      # Get num steps for printing.
      val_samples_or_steps = validation_steps
    elif val_inputs:
      # Get num samples for printing (None when the first array is empty).
      val_samples_or_steps = nest.flatten(val_inputs)[0].shape[0] or None
    else:
      val_samples_or_steps = None

    if mode == ModeKeys.TRAIN and verbose:
      _print_train_info(num_samples_or_steps, val_samples_or_steps,
                        is_dataset)

    # Configure callbacks.
    count_mode = 'steps' if use_steps else 'samples'
    callbacks = cbks.configure_callbacks(
        callbacks,
        model,
        do_validation=do_validation,
        batch_size=batch_size,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        samples=num_samples_or_steps,
        count_mode=count_mode,
        verbose=verbose,
        mode=mode)

    # Find beforehand arrays that need sparse-to-dense conversion.
    if issparse is not None and not use_steps:
      indices_for_conversion_to_dense = []
      feed = _get_model_feed(model, mode)
      for i, (input_data, feed_tensor) in enumerate(zip(ins, feed)):
        if issparse(input_data) and not backend.is_sparse(feed_tensor):
          indices_for_conversion_to_dense.append(i)

    # Select aggregation method.
    if mode == ModeKeys.PREDICT:
      aggregator = training_utils_v1.OutputsAggregator(
          use_steps,
          num_samples=None if steps_per_epoch else num_samples_or_steps,
          steps=steps_per_epoch)
    else:
      aggregator = training_utils_v1.MetricsAggregator(
          use_steps,
          num_samples=None if steps_per_epoch else num_samples_or_steps,
          steps=steps_per_epoch)

    if model._compile_distribution:
      distributed_training_utils_v1._copy_weights_to_distributed_model(
          model, mode)

    callbacks.model.stop_training = False
    callbacks._call_begin_hook(mode)

    initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
        initial_epoch, mode)

    for epoch in range(initial_epoch, epochs):
      if callbacks.model.stop_training:
        break

      # Setup work for each epoch
      epoch_logs = {}
      if mode != ModeKeys.PREDICT:
        # Collecting and resetting metrics has non-zero cost and will
        # needlessly slow down model.predict.
        model.reset_metrics()
      if mode == ModeKeys.TRAIN:
        callbacks.on_epoch_begin(epoch, epoch_logs)

      if use_steps:
        # Step-wise loop.
        if steps_per_epoch is None:
          # Loop over dataset until `OutOfRangeError` is raised.
          target_steps = np.inf
        else:
          # Loop over dataset for the specified number of steps.
          target_steps = steps_per_epoch

        step = 0
        while step < target_steps:
          batch_logs = {'batch': step, 'size': 1}
          callbacks._call_batch_hook(mode, 'begin', step, batch_logs)

          # Get outputs.
          try:
            # `ins` can be callable in tf.distribute.Strategy + eager case.
            if not callable(ins) or (model._distribution_strategy and
                                     not distributed_training_utils_v1
                                     .is_distributing_by_cloning(model)):
              actual_inputs = ins
            else:
              actual_inputs = ins()
            batch_outs = f(actual_inputs)
          except errors.OutOfRangeError:
            if is_dataset:
              # The dataset passed by the user ran out of batches.
              # Now we know the cardinality of the dataset.
              # If steps_per_epoch was specified, then running out of data is
              # unexpected, so we stop training and inform the user.
              if steps_per_epoch:
                callbacks.model.stop_training = True
                logging.warning(
                    'Your dataset ran out of data; interrupting training. '
                    'Make sure that your dataset can generate at least '
                    '`%s * epochs` batches (in this case, %d batches). '
                    'You may need to use the repeat() function when '
                    'building your dataset.'
                    % (steps_name, steps_per_epoch * epochs))
              elif step > 0:
                steps_per_epoch = step
                aggregator.steps = steps_per_epoch
            else:
              # We ran out of batches while the user passed an iterator
              # (legacy).
              callbacks.model.stop_training = True
              logging.warning(
                  'Your dataset iterator ran out of data; '
                  'interrupting training. Make sure that your iterator '
                  'can generate at least `%s * epochs` '
                  'batches (in this case, %d batches). You may need to '
                  'use the repeat() function when building your '
                  'dataset.' % (steps_name, steps_per_epoch * epochs))
            break

          if not isinstance(batch_outs, list):
            batch_outs = [batch_outs]

          if model._distribution_strategy:
            batch_outs = (
                distributed_training_utils_v1._per_replica_aggregate_batch(
                    model._distribution_strategy, batch_outs, model, mode))

          # Aggregate results.
          if step == 0:
            aggregator.create(batch_outs)
          aggregator.aggregate(batch_outs)

          # Callbacks batch end.
          batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
          callbacks._call_batch_hook(mode, 'end', step, batch_logs)
          step += 1

          if callbacks.model.stop_training:
            break
      else:
        # Sample-wise loop.
        index_array = np.arange(num_samples_or_steps)
        if shuffle == 'batch':
          index_array = training_utils_v1.batch_shuffle(
              index_array, batch_size)
        elif shuffle:
          np.random.shuffle(index_array)
        batches = make_batches(num_samples_or_steps, batch_size)
        for batch_index, (batch_start, batch_end) in enumerate(batches):
          batch_ids = index_array[batch_start:batch_end]
          # Slice into a batch.
          if len(batches) == 1:
            # If we only have one batch, do not slice. This takes care of
            # composite tensors in non-Dataset modes; we currently don't
            # support slicing them.
            # TODO(b/133517906): Add slicing support.
            ins_batch = ins
          else:
            try:
              if ins and isinstance(ins[-1], int):
                # Do not slice the training phase flag.
                ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
              else:
                ins_batch = slice_arrays(ins, batch_ids)
            except TypeError:
              raise TypeError('TypeError while preparing batch. '
                              'If using HDF5 input data, '
                              'pass shuffle="batch".')

          # Sparse to dense conversion.
          if issparse is not None:
            for i in indices_for_conversion_to_dense:
              ins_batch[i] = ins_batch[i].toarray()

          # Callbacks batch_begin.
          batch_logs = {'batch': batch_index, 'size': len(batch_ids)}
          callbacks._call_batch_hook(mode, 'begin', batch_index, batch_logs)

          # Get outputs.
          batch_outs = f(ins_batch)
          if not isinstance(batch_outs, list):
            batch_outs = [batch_outs]

          # Aggregate results.
          if batch_index == 0:
            aggregator.create(batch_outs)
          aggregator.aggregate(batch_outs, batch_start, batch_end)

          # Callbacks batch end.
          batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
          callbacks._call_batch_hook(mode, 'end', batch_index, batch_logs)

          if callbacks.model.stop_training:
            break

      aggregator.finalize()
      results = aggregator.results
      epoch_logs = cbks.make_logs(model, epoch_logs, results, mode)
      if len(results) == 1:
        results = results[0]

      # Run the test loop every `validation_freq` epochs during training.
      if (do_validation and
          training_utils_v1.should_run_validation(validation_freq, epoch) and
          not callbacks.model.stop_training):

        if model._compile_distribution:
          # Since we create a new clone from the original model we need to
          # copy the weights back to the original model before we can run
          # validation.
          distributed_training_utils_v1._copy_weights_to_original_model(
              model, ModeKeys.TRAIN)

        val_results = model_iteration(
            model,
            val_inputs,
            targets=val_targets,
            sample_weights=val_sample_weights,
            batch_size=batch_size,
            steps_per_epoch=validation_steps,
            callbacks=callbacks,
            verbose=0,
            mode=ModeKeys.TEST,
            validation_in_fit=True,
            prepared_feed_values_from_dataset=(val_iterator is not None),
            steps_name='validation_steps')
        if not isinstance(val_results, list):
          val_results = [val_results]
        epoch_logs = cbks.make_logs(
            model, epoch_logs, val_results, mode, prefix='val_')
        if val_iterator and epoch < epochs - 1:
          _reinitialize_iterator(val_iterator, model._distribution_strategy)

      if mode == ModeKeys.TRAIN:
        # Epochs only apply to `fit`.
        callbacks.on_epoch_end(epoch, epoch_logs)

      # Reinitialize dataset iterator for the next epoch.
      if reset_dataset_after_each_epoch and epoch < epochs - 1:
        _reinitialize_iterator(input_iterator, model._distribution_strategy)

    model._successful_loop_finish = True
    callbacks._call_end_hook(mode)

    if model._distribution_strategy and model._compile_distribution:
      # TODO(priyag, psv): Copy back metrics to the original model as well?
      distributed_training_utils_v1._copy_weights_to_original_model(
          model, mode)
  finally:
    # Always unwind the strategy scope, on success and on error alike.
    if scope is not None:
      scope.__exit__(None, None, None)

  if mode == ModeKeys.TRAIN:
    return model.history
  return results
453
454
def _get_model_feed(model, mode):
  """Returns the model's feed list that matches the feed values for `mode`.

  In PREDICT mode only the model inputs are fed; in every other mode the
  targets and sample weights follow the inputs, in that order.
  """
  feed_list = model._feed_inputs
  if mode != ModeKeys.PREDICT:
    feed_list = feed_list + model._feed_targets + model._feed_sample_weights
  return feed_list
462
463
def _print_train_info(num_samples_or_steps, val_samples_or_steps, is_dataset):
  """Prints a one-line summary of the training (and validation) set size.

  The unit is 'steps' when training from a dataset and 'samples' otherwise.
  The validation part is only printed when `val_samples_or_steps` is truthy.
  """
  increment = 'steps' if is_dataset else 'samples'
  pieces = [('Train on {0} {increment}', num_samples_or_steps)]
  if val_samples_or_steps:
    pieces.append((', validate on {0} {increment}', val_samples_or_steps))
  print(''.join(
      template.format(count, increment=increment)
      for template, count in pieces))
472
473
def _get_num_samples_or_steps(ins, batch_size, steps_per_epoch):
  """Returns the number of steps, or of samples when no steps are given.

  A truthy `steps_per_epoch` is returned as-is. Otherwise the sample count is
  delegated to `training_utils_v1.check_num_samples`.
  """
  return steps_per_epoch or training_utils_v1.check_num_samples(
      ins, batch_size, steps_per_epoch, 'steps_per_epoch')
480
481
def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
  """Builds the values fed to the model's execution function.

  Args:
    model: Model whose execution function will consume the feed values.
    inputs: List or dict of model inputs, or a dataset / iterator.
    targets: Optional list of model targets.
    sample_weights: Optional list of sample weight arrays.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.

  Returns:
    With a distribution strategy: a zero-argument callable producing the feed
    values when executing eagerly, otherwise the feed values themselves.
    Without one: the list of inputs + targets + sample weights, followed by
    `True` (the learning-phase value) in TRAIN mode when the symbolic learning
    phase is not a plain int.
  """
  strategy = model._distribution_strategy
  if strategy:
    feed_source = inputs
    if isinstance(feed_source, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
      feed_source = distributed_training_utils_v1.get_iterator(
          feed_source, strategy)

    def _distributed_feed_values():
      return distributed_training_utils_v1._prepare_feed_values(
          model, feed_source, targets, sample_weights, mode)

    # Eagerly, the input function must be invoked once per step, so the
    # callable itself is handed back; in graph mode the values are built
    # right away. This applies only to the distribution strategy path, which
    # shares one code path for eager and graph execution.
    # TODO(priyag,omalleyt): Either we should move the training DS with
    # IteratorBase to use training_generator code path, or figure out how to
    # set a symbolic Iterator out of a Dataset when in eager mode.
    if not context.executing_eagerly():
      return _distributed_feed_values()
    return _distributed_feed_values

  dataset_like_types = (dataset_ops.DatasetV1, dataset_ops.DatasetV2,
                        iterator_ops.Iterator)
  if isinstance(inputs, dataset_like_types):
    inputs, targets, sample_weights = model._standardize_user_data(
        inputs, extract_tensors_from_dataset=True)

  flat_inputs = training_utils_v1.ModelInputs(inputs).as_list()
  feed_values = (
      flat_inputs + list(targets or []) + list(sample_weights or []))

  feeds_learning_phase = (
      mode == ModeKeys.TRAIN and
      not isinstance(backend.symbolic_learning_phase(), int))
  if feeds_learning_phase:
    feed_values.append(True)  # Learning phase value.
  return feed_values
530
531
def _get_iterator(inputs, distribution_strategy=None):
  """Returns an iterator over `inputs`.

  Uses the distribution-strategy-aware helper when a strategy is given, and
  the plain `training_utils_v1` helper otherwise.
  """
  if not distribution_strategy:
    return training_utils_v1.get_iterator(inputs)
  return distributed_training_utils_v1.get_iterator(
      inputs, distribution_strategy)
537
538
def _reinitialize_iterator(iterator, distribution_strategy=None):
  """Re-initializes `iterator`, e.g. before the next epoch or validation pass.

  Uses the distribution-strategy-aware helper when a strategy is given.
  """
  if not distribution_strategy:
    training_utils_v1.initialize_iterator(iterator)
    return
  distributed_training_utils_v1.initialize_iterator(
      iterator, distribution_strategy)
545
546
def _make_execution_function(model, mode):
  """Builds the function that runs a single execution step of `model`.

  Without a distribution strategy this is the model's own
  `_make_execution_function(mode)`; otherwise the strategy-specific builder
  from `distributed_training_utils_v1` is used.
  """
  if not model._distribution_strategy:
    return model._make_execution_function(mode)
  return distributed_training_utils_v1._make_execution_function(model, mode)
552
553
def _update_sample_weight_mode(model, mode, inputs):
  """Refreshes the model's sample_weight_mode from the fed sample weights.

  `inputs` is the flat feed list (model inputs + targets + sample weights,
  plus a trailing learning phase value when one is fed) or, in the
  distributed eager case, a callable; sample weights are only extracted from
  a list. Nothing is done in PREDICT mode.
  """
  # Accessing `model._feed_targets` may touch model properties that are not
  # set in PREDICT mode, so bail out early.
  if mode == ModeKeys.PREDICT:
    return

  weights = None
  if not callable(inputs):
    first_weight = len(model._feed_inputs) + len(model._feed_targets)
    # A learning phase placeholder, when present, is the last feed value and
    # must not be mistaken for a sample weight.
    learning_phase_fed = (
        mode == ModeKeys.TRAIN and
        not isinstance(backend.symbolic_learning_phase(), int))
    weights = (inputs[first_weight:-1] if learning_phase_fed
               else inputs[first_weight:])
    model._update_sample_weight_modes(sample_weights=weights)

  # Let the distribution-strategy helper update the distributed model too.
  if model._distribution_strategy:
    distributed_training_utils_v1._update_sample_weight_modes(
        model, mode, weights)
580
# Backwards-compatible aliases for internal users of these loops: each one is
# `model_iteration` with its `mode` pre-bound, and evaluation/prediction also
# disable shuffling.
fit_loop = functools.partial(model_iteration, mode=ModeKeys.TRAIN)
test_loop = functools.partial(
    model_iteration, shuffle=False, mode=ModeKeys.TEST)
predict_loop = functools.partial(
    model_iteration, shuffle=False, mode=ModeKeys.PREDICT)
587
588
class ArrayLikeTrainingLoop(training_utils_v1.TrainingLoop):
  """`TrainingLoop` for array-like inputs.

  This is the default loop for most input types: symbolic tensors, Numpy
  array-likes, and Datasets/iterators in graph mode (which produce symbolic
  tensors). It is used for models with `run_eagerly` = False.

  Each method standardizes the user data with `model._standardize_user_data`
  and then delegates to `fit_loop`, `test_loop` or `predict_loop`.
  """

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    """Trains `model` on array-like data and returns the `fit_loop` result.

    Validation data is taken from `validation_data` when given; otherwise,
    when 0 < `validation_split` < 1, it is split off the training data with
    `training_utils_v1.split_training_and_validation_data`. Extra `**kwargs`
    are accepted and ignored.

    Raises:
      ValueError: if `validation_steps` is given but no validation data is
        available.
    """
    batch_size = model._validate_or_infer_batch_size(
        batch_size, steps_per_epoch, x)

    train_x, train_y, train_weights = model._standardize_user_data(
        x,
        y,
        sample_weight=sample_weight,
        class_weight=class_weight,
        batch_size=batch_size,
        check_steps=True,
        steps_name='steps_per_epoch',
        steps=steps_per_epoch,
        validation_split=validation_split,
        shuffle=shuffle)

    if validation_data:
      val_x, val_y, val_weights = model._prepare_validation_data(
          validation_data, batch_size, validation_steps)
    elif validation_split and 0. < validation_split < 1.:
      split = training_utils_v1.split_training_and_validation_data(
          train_x, train_y, train_weights, validation_split)
      train_x, train_y, train_weights, val_x, val_y, val_weights = split
    elif validation_steps:
      raise ValueError('`validation_steps` should not be specified if '
                       '`validation_data` is None.')
    else:
      val_x = val_y = val_weights = None

    return fit_loop(
        model,
        inputs=train_x,
        targets=train_y,
        sample_weights=train_weights,
        batch_size=batch_size,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        val_inputs=val_x,
        val_targets=val_y,
        val_sample_weights=val_weights,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    """Evaluates `model` on array-like data and returns the `test_loop` result.

    Extra `**kwargs` are accepted and ignored.
    """
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    eval_x, eval_y, eval_weights = model._standardize_user_data(
        x,
        y,
        sample_weight=sample_weight,
        batch_size=batch_size,
        check_steps=True,
        steps_name='steps',
        steps=steps)
    return test_loop(
        model,
        inputs=eval_x,
        targets=eval_y,
        sample_weights=eval_weights,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    """Runs `model` on array-like inputs and returns the `predict_loop` result.

    Extra `**kwargs` are accepted and ignored.
    """
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    pred_x, _, _ = model._standardize_user_data(
        x, check_steps=True, steps_name='steps', steps=steps)
    return predict_loop(
        model,
        pred_x,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks)
710