# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Part of the Keras training engine related to Python generators of array data.
"""
# pylint: disable=protected-access

import functools
import math

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine import training_utils_v1
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest


def model_iteration(model,
                    data,
                    steps_per_epoch=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    validation_data=None,
                    validation_steps=None,
                    validation_freq=1,
                    class_weight=None,
                    max_queue_size=10,
                    workers=1,
                    use_multiprocessing=False,
                    shuffle=False,
                    initial_epoch=0,
                    mode=ModeKeys.TRAIN,
                    batch_size=None,
                    steps_name='steps',
                    **kwargs):
  """Loop function for arrays of data with modes TRAIN/TEST/PREDICT.

  Args:
      model: Keras Model instance.
      data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or `(x, y)` or
        `(x, y, sample_weights)`) or a generator or
        `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
      steps_per_epoch: Total number of steps (batches of samples) before
        declaring one epoch finished and starting the next epoch. Ignored with
        the default value of `None`.
      epochs: Number of times to iterate over the data.
      verbose: 0, 1, or 2. Verbosity mode.
        0 = silent, 1 = progress bar, 2 = one line per epoch.
        Note that the progress bar is not particularly useful when
        logged to a file, so verbose=2 is recommended when not running
        interactively (e.g., in a production environment).
      callbacks: List of callbacks to be called during training.
      validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or
        `(x, y)` or `(x, y, sample_weights)`) or a generator or
        `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
      validation_steps: Total number of steps (batches of samples) before
        declaring validation finished.
      validation_freq: Only relevant if validation data is provided. Integer or
        `collections.abc.Container` instance (e.g. list, tuple, etc.). If an
        integer, specifies how many training epochs to run before a new
        validation run is performed, e.g. `validation_freq=2` runs
        validation every 2 epochs. If a Container, specifies the epochs on
        which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
        validation at the end of the 1st, 2nd, and 10th epochs.
      class_weight: Dictionary mapping class indices to a weight for the class.
      max_queue_size: Integer. Maximum size for the generator queue. If
        unspecified, `max_queue_size` will default to 10.
      workers: Integer. Maximum number of processes to spin up when using
        process-based threading. If unspecified, `workers` will default to 1. If
        0, will execute the generator on the main thread.
      use_multiprocessing: Boolean. If `True`, use process-based threading. If
        unspecified, `use_multiprocessing` will default to `False`. Note that
        because this implementation relies on multiprocessing, you should not
        pass non-picklable arguments to the generator as they can't be passed
        easily to child processes.
      shuffle: Boolean. Whether to shuffle the order of the batches at the
        beginning of each epoch. Only used with instances of `Sequence`
        (`keras.utils.Sequence`). Has no effect when `steps_per_epoch` is not
        `None`.
      initial_epoch: Epoch at which to start training (useful for resuming a
        previous training run).
      mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
      batch_size: Integer batch size or None if unknown. Will only be used if
        `data` is in NumPy/Tensor format.
      steps_name: The string name of the steps argument, either `steps`,
        `validation_steps`, or `steps_per_epoch`. Only used for error message
        formatting.
      **kwargs: Additional arguments for backwards compatibility. `steps` is
        accepted as an alias for `steps_per_epoch`.

  Returns:
      - In TRAIN mode: `History` object.
      - In TEST mode: Evaluation metrics.
      - In PREDICT mode: Outputs of the Model called on inputs.

  Raises:
      ValueError: in case of invalid arguments.
  """
  if 'steps' in kwargs:
    steps_per_epoch = kwargs['steps']

  # Determine the number of steps per epoch and whether we should reset the
  # dataset at the end of each epoch.
  reset_dataset_after_each_epoch = False
  original_dataset = None
  is_dataset = isinstance(data, (dataset_ops.DatasetV2, dataset_ops.DatasetV1))
  if is_dataset:
    original_dataset = data
    if steps_per_epoch is None:
      reset_dataset_after_each_epoch = True
      steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
          model, data, steps_per_epoch, epochs=epochs, steps_name=steps_name)

  # Convert to a format that supports `next(generator)`.
  generator, steps_per_epoch = convert_to_generator_like(
      data,
      steps_per_epoch=steps_per_epoch,
      batch_size=batch_size,
      epochs=epochs - initial_epoch,
      shuffle=shuffle)

  do_validation = validation_data is not None
  is_sequence = isinstance(generator, data_utils.Sequence)
  _validate_arguments(is_sequence, is_dataset, use_multiprocessing, workers,
                      steps_per_epoch, validation_data, validation_steps, mode,
                      kwargs)

  batch_function = _make_execution_function(
      model, mode, class_weight=class_weight)

  # Create the queue for the generator.
  enqueuer = None
  if not is_dataset:
    generator, enqueuer = _make_enqueued_generator(
        generator,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        max_queue_size=max_queue_size,
        shuffle=shuffle)

  num_samples_or_steps, use_steps = _get_num_samples_or_steps(
      data, steps_per_epoch)

  count_mode = 'steps' if use_steps else 'samples'
  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      batch_size=batch_size,
      samples=num_samples_or_steps,
      count_mode=count_mode,
      verbose=verbose,
      mode=mode)

  if mode == ModeKeys.PREDICT:
    aggregator = training_utils_v1.OutputsAggregator(
        True, steps=steps_per_epoch)
  else:
    aggregator = training_utils_v1.MetricsAggregator(
        True, steps=steps_per_epoch)

  should_set_learning_phase = context.executing_eagerly() and model.run_eagerly
  if should_set_learning_phase:
    learning_phase_scope = backend.eager_learning_phase_scope(
        1 if mode == ModeKeys.TRAIN else 0)
    learning_phase_scope.__enter__()

  callbacks.model.stop_training = False
  callbacks._call_begin_hook(mode)

  initial_epoch = model._maybe_load_initial_epoch_from_ckpt(initial_epoch, mode)

  for epoch in range(initial_epoch, epochs):
    if callbacks.model.stop_training:
      break

    # Setup work for each epoch.
    model.reset_metrics()
    epoch_logs = {}
    if mode == ModeKeys.TRAIN:
      callbacks.on_epoch_begin(epoch, epoch_logs)

    if steps_per_epoch is None:
      # Loop over dataset until `OutOfRangeError` is raised.
      target_steps = np.inf
    else:
      # Loop over dataset for the specified number of steps.
      target_steps = steps_per_epoch

    step = 0
    while step < target_steps:
      batch_data = _get_next_batch(generator)
      if batch_data is None:
        if is_dataset:
          # The dataset passed by the user ran out of batches.
          # Now we know the cardinality of the dataset.
          # If steps_per_epoch was specified, then running out of data is
          # unexpected, so we stop training and inform the user.
          if steps_per_epoch:
            callbacks.model.stop_training = True
            logging.warning(
                'Your dataset ran out of data; interrupting training. '
                'Make sure that your dataset can generate at least '
                '`%s * epochs` batches (in this case, %d batches). '
                'You may need to use the repeat() function when '
                'building your dataset.'
                % (steps_name, steps_per_epoch * epochs))
          elif step > 0:
            steps_per_epoch = step
            aggregator.steps = steps_per_epoch
        else:
          # We ran out of batches while the user passed an iterator (legacy).
          callbacks.model.stop_training = True
          logging.warning(
              'Your dataset iterator ran out of data; '
              'interrupting training. Make sure that your iterator '
              'can generate at least `%s * epochs` '
              'batches (in this case, %d batches). You may need to '
              'use the repeat() function when building your '
              'dataset.' % (steps_name, steps_per_epoch * epochs))
        break

      # `batch_size` used for validation data if validation
      # data is NumPy/EagerTensors.
      batch_size = int(nest.flatten(batch_data)[0].shape[0])

      # Callbacks batch begin.
      batch_logs = {'batch': step, 'size': batch_size}
      callbacks._call_batch_hook(mode, 'begin', step, batch_logs)

      is_deferred = not model._is_compiled
      batch_outs = batch_function(*batch_data)
      if not isinstance(batch_outs, list):
        batch_outs = [batch_outs]

      if step == 0:
        aggregator.create(batch_outs)

        if is_deferred:
          # Set callbacks params. We do this here, in the first iteration of
          # this loop, once the model is compiled (deferred build scenario).
          cbks.set_callback_parameters(
              callbacks,
              model,
              do_validation=do_validation,
              batch_size=batch_size,
              epochs=epochs,
              steps_per_epoch=steps_per_epoch,
              samples=num_samples_or_steps,
              verbose=verbose,
              mode=mode)

      # Aggregate results.
      aggregator.aggregate(batch_outs)

      # Callbacks batch end.
      batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
      callbacks._call_batch_hook(mode, 'end', step, batch_logs)
      step += 1

      if callbacks.model.stop_training:
        break

    aggregator.finalize()
    results = aggregator.results
    epoch_logs = cbks.make_logs(model, epoch_logs, results, mode)
    if len(results) == 1:
      results = results[0]

    # Run the test loop every epoch during training.
    if (do_validation and
        training_utils_v1.should_run_validation(validation_freq, epoch) and
        not callbacks.model.stop_training):
      val_results = model_iteration(
          model,
          validation_data,
          steps_per_epoch=validation_steps,
          batch_size=batch_size,
          class_weight=class_weight,
          workers=workers,
          use_multiprocessing=use_multiprocessing,
          max_queue_size=max_queue_size,
          callbacks=callbacks,
          verbose=verbose,
          mode=ModeKeys.TEST,
          steps_name='validation_steps')

      if not isinstance(val_results, list):
        val_results = [val_results]
      epoch_logs = cbks.make_logs(
          model, epoch_logs, val_results, mode, prefix='val_')

    if mode == ModeKeys.TRAIN:
      # Epochs only apply to `fit`.
      callbacks.on_epoch_end(epoch, epoch_logs)

    # Recreate dataset iterator for the next epoch.
    if reset_dataset_after_each_epoch and epoch < epochs - 1:
      generator = dataset_ops.make_one_shot_iterator(original_dataset)

  model._successful_loop_finish = True
  callbacks._call_end_hook(mode)

  if enqueuer is not None:
    enqueuer.stop()

  if should_set_learning_phase:
    learning_phase_scope.__exit__(None, None, None)

  if mode == ModeKeys.TRAIN:
    return model.history
  return results


# Maintain compatibility with the existing names.
fit_generator = functools.partial(model_iteration, mode=ModeKeys.TRAIN)
evaluate_generator = functools.partial(
    model_iteration, mode=ModeKeys.TEST, shuffle=False)
predict_generator = functools.partial(
    model_iteration, mode=ModeKeys.PREDICT, shuffle=False)

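# Illustrative usage sketch (comment only, not part of this module's API):
# `model` is assumed to be a compiled Keras model and `gen` a hypothetical
# Python generator yielding `(x, y)` batch tuples. The aliases above then
# behave like the legacy `Model.fit_generator`-style entry points:
#
#   history = fit_generator(model, gen, steps_per_epoch=100, epochs=5)
#   loss_and_metrics = evaluate_generator(model, gen, steps=20)
#   predictions = predict_generator(model, gen, steps=20)
#
# For training, `gen` must be able to yield `steps_per_epoch * epochs` batches.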

def _get_next_batch(generator):
  """Retrieves the next batch of input data."""
  try:
    generator_output = next(generator)
  except (StopIteration, errors.OutOfRangeError):
    return None

  if not isinstance(generator_output, tuple):
    # Always wrap in a tuple.
    generator_output = (generator_output,)
  if len(generator_output) not in [1, 2, 3]:
    raise ValueError(
        'Output of generator should be a tuple of 1, 2, or 3 '
        'elements: (input,) or (input, target) or '
        '(input, target, sample_weights). Received {}'.format(generator_output))
  return generator_output

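# For example (hedged sketch of the normalization above): a generator that
# yields bare arrays has each element wrapped in a 1-tuple, so downstream code
# can uniformly unpack batches as `(x,)`, `(x, y)`, or `(x, y, sample_weights)`:
#
#   def gen():
#     yield np.zeros((8, 4))                     # bare x -> returned as (x,)
#     yield (np.zeros((8, 4)), np.ones((8, 1)))  # (x, y) passes through
#
#   g = gen()
#   assert len(_get_next_batch(g)) == 1  # (x,)
#   assert len(_get_next_batch(g)) == 2  # (x, y)
#   assert _get_next_batch(g) is None    # generator exhausted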

def _validate_arguments(is_sequence, is_dataset, use_multiprocessing, workers,
                        steps_per_epoch, validation_data, validation_steps,
                        mode, kwargs):
  """Raises errors if arguments are invalid.

  Args:
    is_sequence: Boolean, whether data is a `keras.utils.data_utils.Sequence`
      instance.
    is_dataset: Boolean, whether data is a dataset instance.
    use_multiprocessing: Boolean. If `True`, use process-based threading. If
      unspecified, `use_multiprocessing` will default to `False`. Note that
      because this implementation relies on multiprocessing, you should not pass
      non-picklable arguments to the generator as they can't be passed easily to
      child processes.
    workers: Integer. Maximum number of processes to spin up when using
      process-based threading. If unspecified, `workers` will default to 1. If
      0, will execute the generator on the main thread.
    steps_per_epoch: Total number of steps (batches of samples) before declaring
      one epoch finished and starting the next epoch. Ignored with the default
      value of `None`.
    validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or `(x,
      y)` or `(x, y, sample_weights)`) or a generator or
      `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
    validation_steps: Total number of steps (batches of samples) before
      declaring validation finished.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
    kwargs: Additional arguments for backwards compatibility.

  Raises:
    ValueError: If `steps_per_epoch` or `validation_steps` are not passed
      for data types that require them, or if unrecognized keyword
      arguments are passed.
  """
  if not is_sequence and use_multiprocessing and workers > 1:
    logging.warning(
        UserWarning('Using a generator with `use_multiprocessing=True`'
                    ' and multiple workers may duplicate your data.'
                    ' Please consider using the `keras.utils.Sequence`'
                    ' class.'))

  if steps_per_epoch is None and not is_dataset:
    arg_name = 'steps_per_epoch' if mode == ModeKeys.TRAIN else 'steps'
    raise ValueError('Please specify the number of steps via the '
                     '`{}` argument.'.format(arg_name))

  val_gen = (
      data_utils.is_generator_or_sequence(validation_data) or
      isinstance(validation_data, iterator_ops.IteratorBase))
  if (val_gen and not isinstance(validation_data, data_utils.Sequence) and
      not validation_steps):
    raise ValueError('Please specify the `validation_steps` argument.')

  if any(k != 'steps' for k in kwargs):
    raise ValueError('Invalid arguments passed: {}'.format(
        [k for k in kwargs if k != 'steps']))


def convert_to_generator_like(data,
                              batch_size=None,
                              steps_per_epoch=None,
                              epochs=1,
                              shuffle=False):
  """Make a generator out of NumPy or EagerTensor inputs.

  Args:
    data: Either a generator or `keras.utils.data_utils.Sequence` object or
      `Dataset`, `Iterator`, or a {1,2,3}-tuple of NumPy arrays or EagerTensors.
      If a tuple, the elements represent `(x, y, sample_weights)` and may be
      `None` or `[None]`.
    batch_size: Used when creating a generator out of tuples of NumPy arrays or
      EagerTensors.
    steps_per_epoch: Steps of the generator to run each epoch. If `None` the
      number of steps will be read from the data (for
      `keras.utils.data_utils.Sequence` types).
    epochs: Total number of epochs to run.
    shuffle: Whether the data should be shuffled.

  Returns:
    - Generator, `keras.utils.data_utils.Sequence`, or `Iterator`.

  Raises:
    - ValueError: If `batch_size` is not provided for NumPy or EagerTensor
      inputs.
  """
  if isinstance(data, tuple):
    # Scrub `Nones` that might have been passed for `targets`, `sample_weights`.
    data = tuple(
        ele for ele in data if not all(e is None for e in nest.flatten(ele)))

  if data_utils.is_generator_or_sequence(data) or isinstance(
      data, iterator_ops.IteratorBase):
    if isinstance(data, data_utils.Sequence):
      if steps_per_epoch is None:
        steps_per_epoch = len(data)
    return data, steps_per_epoch
  if isinstance(data, dataset_ops.DatasetV2):
    return dataset_ops.make_one_shot_iterator(data), steps_per_epoch

  # Create generator from NumPy or EagerTensor Input.
  num_samples = int(nest.flatten(data)[0].shape[0])
  if batch_size is None:
    raise ValueError(
        'When passing input data as arrays, do not specify '
        '`steps_per_epoch`/`steps` argument. Please use `batch_size` instead.')
  steps_per_epoch = int(math.ceil(num_samples / batch_size))

  def _gen(data):
    """Makes a generator out of a structure of NumPy/EagerTensors."""
    index_array = np.arange(num_samples)
    for _ in range(epochs):
      if shuffle:
        np.random.shuffle(index_array)
      batches = generic_utils.make_batches(num_samples, batch_size)
      for (batch_start, batch_end) in batches:
        batch_ids = index_array[batch_start:batch_end]
        flat_batch_data = training_utils.slice_arrays(
            nest.flatten(data), batch_ids, contiguous=(not shuffle))
        yield nest.pack_sequence_as(data, flat_batch_data)

  return _gen(data), steps_per_epoch

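# A minimal sketch of the array path above (illustrative values only): a tuple
# of NumPy arrays becomes a generator yielding `batch_size`-sized slices, with
# the step count inferred as ceil(num_samples / batch_size):
#
#   x = np.arange(10).reshape(10, 1)
#   y = np.arange(10)
#   gen, steps = convert_to_generator_like((x, y), batch_size=4, epochs=1)
#   assert steps == 3                  # ceil(10 / 4)
#   batch_x, batch_y = next(gen)
#   assert batch_x.shape == (4, 1)     # first contiguous slice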

def _make_enqueued_generator(generator,
                             workers=1,
                             use_multiprocessing=False,
                             max_queue_size=10,
                             shuffle=False):
  """Create a buffered queue of next elements of the generator."""
  is_sequence = isinstance(generator, data_utils.Sequence)
  enqueuer = None
  if workers > 0:
    if is_sequence:
      enqueuer = data_utils.OrderedEnqueuer(
          generator, use_multiprocessing=use_multiprocessing, shuffle=shuffle)
    else:
      enqueuer = data_utils.GeneratorEnqueuer(
          generator, use_multiprocessing=use_multiprocessing)
    enqueuer.start(workers=workers, max_queue_size=max_queue_size)
    output_generator = enqueuer.get()
  else:
    if is_sequence:
      output_generator = data_utils.iter_sequence_infinite(generator)
    else:
      output_generator = generator
  return output_generator, enqueuer

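# Hedged usage sketch (`MySequence` is a hypothetical `keras.utils.Sequence`
# subclass): with `workers > 0`, batches are prefetched on background threads
# or processes, and the caller is responsible for stopping the enqueuer:
#
#   gen, enqueuer = _make_enqueued_generator(
#       MySequence(), workers=2, use_multiprocessing=False, max_queue_size=10)
#   batch = next(gen)
#   if enqueuer is not None:
#     enqueuer.stop()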

def _make_execution_function(model, mode, class_weight=None):
  """Makes function to run one step of model execution."""
  if mode == ModeKeys.TRAIN:
    f = functools.partial(model.train_on_batch, class_weight=class_weight)
  elif mode == ModeKeys.TEST:
    f = model.test_on_batch
  else:
    # Match signature of other modes to allow
    # 1, 2, or 3-tuples from generator
    def predict_on_batch(x, y=None, sample_weights=None):  # pylint: disable=unused-argument
      return model.predict_on_batch(x)

    f = predict_on_batch

  # Maintain stateful metrics across batch-level calls.
  if mode != ModeKeys.PREDICT:
    f = functools.partial(f, reset_metrics=False)

  return f

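# For instance (illustrative sketch): in TRAIN mode the returned function
# forwards to `model.train_on_batch` with `reset_metrics=False`, so stateful
# metrics keep accumulating across the batches of an epoch:
#
#   f = _make_execution_function(model, ModeKeys.TRAIN)
#   batch_outs = f(x_batch, y_batch)
#   # equivalent to: model.train_on_batch(
#   #     x_batch, y_batch, class_weight=None, reset_metrics=False)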

def _get_num_samples_or_steps(data, steps_per_epoch):
  """Returns number of samples or steps, and whether to use steps count mode."""
  flat_inputs = nest.flatten(data)
  if hasattr(flat_inputs[0], 'shape'):
    return int(flat_inputs[0].shape[0]), False
  return steps_per_epoch, True

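# E.g. (sketch): array-like inputs report a sample count, anything without a
# `shape` falls back to the step count; the boolean drives the progress bar's
# `count_mode` selection in `model_iteration`:
#
#   _get_num_samples_or_steps((np.zeros((32, 4)), np.zeros(32)), 10)
#   # -> (32, False): count samples.
#   _get_num_samples_or_steps(iter([]), 10)
#   # -> (10, True): count steps.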

class GeneratorOrSequenceTrainingLoop(training_utils_v1.TrainingLoop):
  """Generator-like.

  Input is a Python generator or a `Sequence` object.

  The difference between this class and `GeneratorLikeTrainingLoop` is that
  this class only handles inputs with x, y, and sample_weight fused into one
  param.
  """

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          max_queue_size=10,
          workers=1,
          use_multiprocessing=False):
    model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
    training_utils_v1.check_generator_arguments(
        y, sample_weight, validation_split=validation_split)
    return fit_generator(
        model,
        x,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        class_weight=class_weight,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               max_queue_size=10,
               workers=1,
               use_multiprocessing=False):
    model._validate_or_infer_batch_size(batch_size, steps, x)
    training_utils_v1.check_generator_arguments(y, sample_weight)
    return evaluate_generator(
        model,
        x,
        steps=steps,
        verbose=verbose,
        callbacks=callbacks,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              max_queue_size=10,
              workers=1,
              use_multiprocessing=False):
    model._validate_or_infer_batch_size(batch_size, steps, x)
    return predict_generator(
        model,
        x,
        steps=steps,
        verbose=verbose,
        callbacks=callbacks,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing)


class EagerDatasetOrIteratorTrainingLoop(training_utils_v1.TrainingLoop):
  """A non-distributed Dataset or iterator in eager execution."""

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
    # Make sure that y, sample_weights, validation_split are not passed.
    training_utils_v1.validate_dataset_input(x, y, sample_weight,
                                             validation_split)
    if (isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)) and
        shuffle):
      training_utils_v1.verify_dataset_shuffled(x)

    return fit_generator(
        model,
        x,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        class_weight=class_weight,
        workers=0,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    model._validate_or_infer_batch_size(batch_size, steps, x)
    # Make sure that y, sample_weights, validation_split are not passed.
    training_utils_v1.validate_dataset_input(x, y, sample_weight)
    return evaluate_generator(
        model, x, steps=steps, verbose=verbose, workers=0, callbacks=callbacks)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    model._validate_or_infer_batch_size(batch_size, steps, x)
    return predict_generator(
        model, x, steps=steps, verbose=verbose, workers=0, callbacks=callbacks)

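# Hedged sketch of how this loop class delegates (assumes hypothetical
# `features`/`labels` NumPy arrays with 80 rows, so the dataset yields 10
# batches of 8): `fit` validates that `y`/`sample_weight` are absent for
# dataset inputs, then hands the dataset to `fit_generator` with `workers=0`
# so iteration stays on the main thread:
#
#   ds = dataset_ops.DatasetV2.from_tensor_slices((features, labels)).batch(8)
#   history = EagerDatasetOrIteratorTrainingLoop().fit(
#       model, x=ds, epochs=2, steps_per_epoch=10)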

class GeneratorLikeTrainingLoop(training_utils_v1.TrainingLoop):
  """TrainingLoop that handles inputs like a Python generator.

  This is the default handler for most input data types, including symbolic
  tensors and NumPy array-likes, as well as Datasets and iterators in graph
  mode (since they generate symbolic tensors). This class is also used to
  handle models with `run_eagerly=True`.
  """

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    batch_size = model._validate_or_infer_batch_size(batch_size,
                                                     steps_per_epoch, x)
    x, y, sample_weights = model._standardize_user_data(
        x,
        y,
        sample_weight=sample_weight,
        class_weight=class_weight,
        batch_size=batch_size,
        check_steps=True,
        steps_name='steps_per_epoch',
        steps=steps_per_epoch,
        validation_split=validation_split,
        shuffle=shuffle)

    if validation_data:
      validation_data = model._prepare_validation_data(validation_data,
                                                       batch_size,
                                                       validation_steps)
    elif validation_split and 0. < validation_split < 1.:
      (x, y, sample_weights, val_x, val_y,
       val_sample_weights) = (
           training_utils_v1.split_training_and_validation_data(
               x, y, sample_weights, validation_split))
      validation_data = (val_x, val_y, val_sample_weights)
    else:
      if validation_steps:
        raise ValueError('`validation_steps` should not be specified if '
                         '`validation_data` is None.')

    return fit_generator(
        model, (x, y, sample_weights),
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        workers=0,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    x, y, sample_weights = model._standardize_user_data(
        x,
        y,
        sample_weight=sample_weight,
        batch_size=batch_size,
        check_steps=True,
        steps_name='steps',
        steps=steps)
    return evaluate_generator(
        model, (x, y, sample_weights),
        steps=steps,
        batch_size=batch_size,
        verbose=verbose,
        workers=0,
        callbacks=callbacks)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    x, _, _ = model._standardize_user_data(
        x, check_steps=True, steps_name='steps', steps=steps)
    return predict_generator(
        model,
        x,
        steps=steps,
        batch_size=batch_size,
        verbose=verbose,
        workers=0,
        callbacks=callbacks)