# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Part of the Keras training engine related to Python generators of array data.
"""
# pylint: disable=protected-access

import functools
import math

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine import training_utils_v1
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest


def model_iteration(model,
                    data,
                    steps_per_epoch=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    validation_data=None,
                    validation_steps=None,
                    validation_freq=1,
                    class_weight=None,
                    max_queue_size=10,
                    workers=1,
                    use_multiprocessing=False,
                    shuffle=False,
                    initial_epoch=0,
                    mode=ModeKeys.TRAIN,
                    batch_size=None,
                    steps_name='steps',
                    **kwargs):
  """Loop function for generator-like data with modes TRAIN/TEST/PREDICT.

  The input is first converted to something that supports `next()` (see
  `convert_to_generator_like`). Unless it is a dataset, it is then wrapped in
  an enqueuer (see `_make_enqueued_generator`). Finally it is consumed batch by
  batch with `train_on_batch`, `test_on_batch` or `predict_on_batch`, depending
  on `mode`. The enqueuer and the eager learning-phase scope are released even
  if the loop is interrupted by an exception.

  Args:
      model: Keras Model instance.
      data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or `(x, y)` or
        `(x, y, sample_weights)`) or a generator or
        `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
      steps_per_epoch: Total number of steps (batches of samples) before
        declaring one epoch finished and starting the next epoch. If `None`,
        it is inferred where possible: from the dataset for Dataset inputs,
        from `len()` for `Sequence` inputs, and from `batch_size` for
        NumPy/Tensor inputs. It must be provided for plain generators and
        iterators.
      epochs: Number of times to iterate over the data.
      verbose: 0, 1, or 2. Verbosity mode.
        0 = silent, 1 = progress bar, 2 = one line per epoch.
        Note that the progress bar is not particularly useful when
        logged to a file, so verbose=2 is recommended when not running
        interactively (eg, in a production environment).
      callbacks: List of callbacks to be called during training.
      validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or
        `(x, y)` or `(x, y, sample_weights)`) or a generator or
        `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
      validation_steps: Total number of steps (batches of samples) before
        declaring validation finished.
      validation_freq: Only relevant if validation data is provided. Integer or
        `collections.abc.Container` instance (e.g. list, tuple, etc.). If an
        integer, specifies how many training epochs to run before a new
        validation run is performed, e.g. `validation_freq=2` runs
        validation every 2 epochs. If a Container, specifies the epochs on
        which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
        validation at the end of the 1st, 2nd, and 10th epochs.
      class_weight: Dictionary mapping class indices to a weight for the class.
      max_queue_size: Integer. Maximum size for the generator queue. If
        unspecified, `max_queue_size` will default to 10.
      workers: Integer. Maximum number of processes to spin up when using
        process-based threading. If unspecified, `workers` will default to 1.
        If 0, will execute the generator on the main thread.
      use_multiprocessing: Boolean. If `True`, use process-based threading. If
        unspecified, `use_multiprocessing` will default to `False`. Note that
        because this implementation relies on multiprocessing, you should not
        pass non-picklable arguments to the generator as they can't be passed
        easily to children processes.
      shuffle: Boolean. Whether to shuffle the order of the batches at the
        beginning of each epoch. Applies to NumPy/Tensor inputs, and to
        `Sequence` instances (`keras.utils.Sequence`) when `workers > 0`.
      initial_epoch: Epoch at which to start training (useful for resuming a
        previous training run).
      mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
      batch_size: Integer batch size or None if unknown. Will only be used if
        `data` is in NumPy/Tensor format.
      steps_name: The string name of the steps argument, either `steps`,
        `validation_steps`, or `steps_per_epoch`. Only used for error message
        formatting.
      **kwargs: Additional arguments for backwards compatibility. `steps` is
        accepted as an alias for `steps_per_epoch`.

  Returns:
      - In TRAIN mode: `History` object.
      - In TEST mode: Evaluation metrics.
      - In PREDICT mode: Outputs of the Model called on inputs.

  Raises:
      ValueError: in case of invalid arguments.
  """
  if 'steps' in kwargs:
    steps_per_epoch = kwargs['steps']

  # Determine the number of steps per epoch and whether we should reset the
  # dataset at the end of each epoch.
  reset_dataset_after_each_epoch = False
  original_dataset = None
  is_dataset = isinstance(data, (dataset_ops.DatasetV2, dataset_ops.DatasetV1))
  if is_dataset:
    original_dataset = data
    if steps_per_epoch is None:
      reset_dataset_after_each_epoch = True
      steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
          model, data, steps_per_epoch, epochs=epochs, steps_name=steps_name)

  # Convert to a format that supports `next(generator)`.
  generator, steps_per_epoch = convert_to_generator_like(
      data,
      steps_per_epoch=steps_per_epoch,
      batch_size=batch_size,
      epochs=epochs - initial_epoch,
      shuffle=shuffle)

  do_validation = validation_data is not None
  is_sequence = isinstance(generator, data_utils.Sequence)
  _validate_arguments(is_sequence, is_dataset, use_multiprocessing, workers,
                      steps_per_epoch, validation_data, validation_steps, mode,
                      kwargs)

  batch_function = _make_execution_function(
      model, mode, class_weight=class_weight)

  # Resources that must be released however the loop exits. They start as
  # `None` so the `finally` clause only cleans up what was actually acquired.
  enqueuer = None
  learning_phase_scope = None
  try:
    # Create the queue for the generator.
    if not is_dataset:
      generator, enqueuer = _make_enqueued_generator(
          generator,
          workers=workers,
          use_multiprocessing=use_multiprocessing,
          max_queue_size=max_queue_size,
          shuffle=shuffle)

    num_samples_or_steps, use_steps = _get_num_samples_or_steps(
        data, steps_per_epoch)

    count_mode = 'steps' if use_steps else 'samples'
    callbacks = cbks.configure_callbacks(
        callbacks,
        model,
        do_validation=do_validation,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        samples=num_samples_or_steps,
        count_mode=count_mode,
        verbose=verbose,
        mode=mode)

    if mode == ModeKeys.PREDICT:
      aggregator = training_utils_v1.OutputsAggregator(
          True, steps=steps_per_epoch)
    else:
      aggregator = training_utils_v1.MetricsAggregator(
          True, steps=steps_per_epoch)

    should_set_learning_phase = (
        context.executing_eagerly() and model.run_eagerly)
    if should_set_learning_phase:
      scope = backend.eager_learning_phase_scope(
          1 if mode == ModeKeys.TRAIN else 0)
      scope.__enter__()
      # Record the scope only once it has been entered successfully, so the
      # `finally` clause never exits a scope that was never entered.
      learning_phase_scope = scope

    callbacks.model.stop_training = False
    callbacks._call_begin_hook(mode)

    initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
        initial_epoch, mode)

    for epoch in range(initial_epoch, epochs):
      if callbacks.model.stop_training:
        break

      # Setup work for each epoch.
      model.reset_metrics()
      epoch_logs = {}
      if mode == ModeKeys.TRAIN:
        callbacks.on_epoch_begin(epoch, epoch_logs)

      if steps_per_epoch is None:
        # Loop over dataset until `OutOfRangeError` is raised.
        target_steps = np.inf
      else:
        # Loop over dataset for the specified number of steps.
        target_steps = steps_per_epoch

      step = 0
      while step < target_steps:
        batch_data = _get_next_batch(generator)
        if batch_data is None:
          if is_dataset:
            # The dataset passed by the user ran out of batches.
            # Now we know the cardinality of the dataset.
            # If steps_per_epoch was specified, then running out of data is
            # unexpected, so we stop training and inform the user.
            if steps_per_epoch:
              callbacks.model.stop_training = True
              logging.warning(
                  'Your dataset ran out of data; interrupting training. '
                  'Make sure that your dataset can generate at least '
                  '`%s * epochs` batches (in this case, %d batches). '
                  'You may need to use the repeat() function when '
                  'building your dataset.'
                  % (steps_name, steps_per_epoch * epochs))
            elif step > 0:
              steps_per_epoch = step
              aggregator.steps = steps_per_epoch
          else:
            # We ran out of batches while the user passed an iterator (legacy).
            callbacks.model.stop_training = True
            logging.warning(
                'Your dataset iterator ran out of data; '
                'interrupting training. Make sure that your iterator '
                'can generate at least `%s * epochs` '
                'batches (in this case, %d batches). You may need to '
                'use the repeat() function when building your '
                'dataset.' % (steps_name, steps_per_epoch * epochs))
          break

        # `batch_size` used for validation data if validation
        # data is NumPy/EagerTensors.
        batch_size = int(nest.flatten(batch_data)[0].shape[0])

        # Callbacks batch begin.
        batch_logs = {'batch': step, 'size': batch_size}
        callbacks._call_batch_hook(mode, 'begin', step, batch_logs)

        is_deferred = not model._is_compiled
        batch_outs = batch_function(*batch_data)
        if not isinstance(batch_outs, list):
          batch_outs = [batch_outs]

        if step == 0:
          aggregator.create(batch_outs)

          if is_deferred:
            # Set callbacks params. We do this here when model is compiled only
            # in the first iteration of this loop (deferred build scenario).
            cbks.set_callback_parameters(
                callbacks,
                model,
                do_validation=do_validation,
                batch_size=batch_size,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                samples=num_samples_or_steps,
                verbose=verbose,
                mode=mode)

        # Aggregate results.
        aggregator.aggregate(batch_outs)

        # Callbacks batch end.
        batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
        callbacks._call_batch_hook(mode, 'end', step, batch_logs)
        step += 1

        if callbacks.model.stop_training:
          break

      aggregator.finalize()
      results = aggregator.results
      epoch_logs = cbks.make_logs(model, epoch_logs, results, mode)
      if len(results) == 1:
        results = results[0]

      # Run the test loop every epoch during training.
      if (do_validation and
          training_utils_v1.should_run_validation(validation_freq, epoch) and
          not callbacks.model.stop_training):
        val_results = model_iteration(
            model,
            validation_data,
            steps_per_epoch=validation_steps,
            batch_size=batch_size,
            class_weight=class_weight,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            max_queue_size=max_queue_size,
            callbacks=callbacks,
            verbose=verbose,
            mode=ModeKeys.TEST,
            steps_name='validation_steps')

        if not isinstance(val_results, list):
          val_results = [val_results]
        epoch_logs = cbks.make_logs(
            model, epoch_logs, val_results, mode, prefix='val_')

      if mode == ModeKeys.TRAIN:
        # Epochs only apply to `fit`.
        callbacks.on_epoch_end(epoch, epoch_logs)

      # Recreate dataset iterator for the next epoch.
      if reset_dataset_after_each_epoch and epoch < epochs - 1:
        generator = dataset_ops.make_one_shot_iterator(original_dataset)

    model._successful_loop_finish = True
    callbacks._call_end_hook(mode)
  finally:
    # Release resources even when the loop is interrupted by an exception
    # (e.g. an error raised by the generator or the model, or a
    # KeyboardInterrupt); otherwise enqueuer workers would keep running.
    if enqueuer is not None:
      enqueuer.stop()

    if learning_phase_scope is not None:
      learning_phase_scope.__exit__(None, None, None)

  if mode == ModeKeys.TRAIN:
    return model.history
  return results


# Maintain compatibility with the existing names.
fit_generator = functools.partial(model_iteration, mode=ModeKeys.TRAIN)
evaluate_generator = functools.partial(
    model_iteration, mode=ModeKeys.TEST, shuffle=False)
predict_generator = functools.partial(
    model_iteration, mode=ModeKeys.PREDICT, shuffle=False)


def _get_next_batch(generator):
  """Retrieves the next batch of input data.

  Args:
    generator: Object supporting `next()` (generator, enqueuer output or
      dataset iterator).

  Returns:
    The batch as a tuple of 1, 2 or 3 elements (a non-tuple output is wrapped
    into a 1-tuple), or `None` when the generator raises `StopIteration` or
    `OutOfRangeError`.

  Raises:
    ValueError: If the generator output is a tuple whose length is not 1, 2
      or 3.
  """
  try:
    generator_output = next(generator)
  except (StopIteration, errors.OutOfRangeError):
    return None

  if not isinstance(generator_output, tuple):
    # Always wrap in a tuple.
    generator_output = (generator_output,)
  if len(generator_output) not in [1, 2, 3]:
    raise ValueError(
        'Output of generator should be a tuple of 1 or 2 or 3 '
        'elements: (input,) or (input, target) or '
        '(input, target, sample_weights). Received {}'.format(generator_output))
  return generator_output


def _validate_arguments(is_sequence, is_dataset, use_multiprocessing, workers,
                        steps_per_epoch, validation_data, validation_steps,
                        mode, kwargs):
  """Raises errors if arguments are invalid.

  Args:
    is_sequence: Boolean, whether data is a `keras.utils.data_utils.Sequence`
      instance.
    is_dataset: Boolean, whether data is a dataset instance.
    use_multiprocessing: Boolean. If `True`, use process-based threading. If
      unspecified, `use_multiprocessing` will default to `False`. Note that
      because this implementation relies on multiprocessing, you should not pass
      non-picklable arguments to the generator as they can't be passed easily to
      children processes.
    workers: Integer. Maximum number of processes to spin up when using
      process-based threading. If unspecified, `workers` will default to 1. If
      0, will execute the generator on the main thread.
    steps_per_epoch: Total number of steps (batches of samples) before declaring
      one epoch finished and starting the next epoch. Must not be `None` unless
      the data is a dataset.
    validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or `(x,
      y)` or `(x, y, sample_weights)`) or a generator or
      `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
    validation_steps: Total number of steps (batches of samples) before
      declaring validation finished.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
    kwargs: Additional arguments for backwards compatibility. Only `steps` is
      accepted.

  Raises:
    ValueError: If `steps_per_epoch` or `validation_steps` are not passed
      for data types that require them, or if unrecognized keyword
      arguments are passed.
  """
  if not is_sequence and use_multiprocessing and workers > 1:
    logging.warning(
        UserWarning('Using a generator with `use_multiprocessing=True`'
                    ' and multiple workers may duplicate your data.'
                    ' Please consider using the `keras.utils.Sequence`'
                    ' class.'))

  if steps_per_epoch is None and not is_dataset:
    arg_name = 'steps_per_epoch' if mode == ModeKeys.TRAIN else 'steps'
    raise ValueError('Please specify the number of steps via the '
                     '`{}` argument.'.format(arg_name))

  val_gen = (
      data_utils.is_generator_or_sequence(validation_data) or
      isinstance(validation_data, iterator_ops.IteratorBase))
  if (val_gen and not isinstance(validation_data, data_utils.Sequence) and
      not validation_steps):
    raise ValueError('Please specify the `validation_steps` argument.')

  if any(k != 'steps' for k in kwargs):
    raise ValueError('Invalid arguments passed: {}'.format(
        [k for k in kwargs if k != 'steps']))


def convert_to_generator_like(data,
                              batch_size=None,
                              steps_per_epoch=None,
                              epochs=1,
                              shuffle=False):
  """Make a generator out of NumPy or EagerTensor inputs.

  Generators, `Sequence` objects and iterators are returned unchanged, and
  `DatasetV2` instances are converted to a one-shot iterator.

  Args:
    data: Either a generator or `keras.utils.data_utils.Sequence` object or
      `Dataset`, `Iterator`, or a {1,2,3}-tuple of NumPy arrays or EagerTensors.
      If a tuple, the elements represent `(x, y, sample_weights)` and may be
      `None` or `[None]`.
    batch_size: Used when creating a generator out of tuples of NumPy arrays or
      EagerTensors.
    steps_per_epoch: Steps of the generator to run each epoch. If `None` the
      number of steps will be read from the data (for
      `keras.utils.data_utils.Sequence` types).
    epochs: Total number of epochs to run.
    shuffle: Whether the data should be shuffled.

  Returns:
    A tuple `(generator, steps_per_epoch)`, where `generator` is a generator,
    `keras.utils.data_utils.Sequence`, or `Iterator`, and `steps_per_epoch` is
    the (possibly inferred) number of steps per epoch. For NumPy/EagerTensor
    inputs `steps_per_epoch` is recomputed as
    `ceil(num_samples / batch_size)`.

  Raises:
    - ValueError: If `batch_size` is not provided for NumPy or EagerTensor
      inputs.
  """
  if isinstance(data, tuple):
    # Scrub `Nones` that might have been passed for `targets`, `sample_weights`.
    data = tuple(
        ele for ele in data if not all(e is None for e in nest.flatten(ele)))

  if data_utils.is_generator_or_sequence(data) or isinstance(
      data, iterator_ops.IteratorBase):
    if isinstance(data, data_utils.Sequence):
      if steps_per_epoch is None:
        steps_per_epoch = len(data)
    return data, steps_per_epoch
  if isinstance(data, dataset_ops.DatasetV2):
    return dataset_ops.make_one_shot_iterator(data), steps_per_epoch

  # Create generator from NumPy or EagerTensor Input.
  num_samples = int(nest.flatten(data)[0].shape[0])
  if batch_size is None:
    raise ValueError(
        'When passing input data as arrays, do not specify '
        '`steps_per_epoch`/`steps` argument. Please use `batch_size` instead.')
  steps_per_epoch = int(math.ceil(num_samples / batch_size))

  def _gen(data):
    """Makes a generator out of a structure of NumPy/EagerTensors."""
    index_array = np.arange(num_samples)
    for _ in range(epochs):
      if shuffle:
        np.random.shuffle(index_array)
      batches = generic_utils.make_batches(num_samples, batch_size)
      for (batch_start, batch_end) in batches:
        batch_ids = index_array[batch_start:batch_end]
        # Contiguous slicing is only valid when indices are not shuffled.
        flat_batch_data = training_utils.slice_arrays(
            nest.flatten(data), batch_ids, contiguous=(not shuffle))
        yield nest.pack_sequence_as(data, flat_batch_data)

  return _gen(data), steps_per_epoch


def _make_enqueued_generator(generator,
                             workers=1,
                             use_multiprocessing=False,
                             max_queue_size=10,
                             shuffle=False):
  """Create a buffered queue of next elements of the generator.

  Args:
    generator: A Python generator or `keras.utils.data_utils.Sequence`.
    workers: Number of workers for the enqueuer. If 0, no enqueuer is created
      and the generator is consumed on the calling thread.
    use_multiprocessing: Passed through to the enqueuer.
    max_queue_size: Maximum size of the enqueuer's queue.
    shuffle: Passed to `OrderedEnqueuer` for `Sequence` inputs only.

  Returns:
    A tuple `(output_generator, enqueuer)`. `enqueuer` is `None` when
    `workers == 0`; otherwise it has been started and the caller is
    responsible for calling `enqueuer.stop()`.
  """
  is_sequence = isinstance(generator, data_utils.Sequence)
  enqueuer = None
  if workers > 0:
    if is_sequence:
      enqueuer = data_utils.OrderedEnqueuer(
          generator, use_multiprocessing=use_multiprocessing, shuffle=shuffle)
    else:
      enqueuer = data_utils.GeneratorEnqueuer(
          generator, use_multiprocessing=use_multiprocessing)
    enqueuer.start(workers=workers, max_queue_size=max_queue_size)
    output_generator = enqueuer.get()
  else:
    if is_sequence:
      output_generator = data_utils.iter_sequence_infinite(generator)
    else:
      output_generator = generator
  return output_generator, enqueuer


def _make_execution_function(model, mode, class_weight=None):
  """Makes function to run one step of model execution.

  Args:
    model: Keras Model instance.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
    class_weight: Forwarded to `model.train_on_batch` in TRAIN mode only.

  Returns:
    A callable accepting `(x, y=None, sample_weights=None)`-style positional
    batch elements and dispatching to `train_on_batch`, `test_on_batch` or
    `predict_on_batch`.
  """
  if mode == ModeKeys.TRAIN:
    f = functools.partial(model.train_on_batch, class_weight=class_weight)
  elif mode == ModeKeys.TEST:
    f = model.test_on_batch
  else:
    # Match signature of other modes to allow
    # 1, 2, or 3-tuples from generator
    def predict_on_batch(x, y=None, sample_weights=None):  # pylint: disable=unused-argument
      return model.predict_on_batch(x)

    f = predict_on_batch

  # Maintain stateful metrics across batch-level calls.
  if mode != ModeKeys.PREDICT:
    f = functools.partial(f, reset_metrics=False)

  return f


def _get_num_samples_or_steps(data, steps_per_epoch):
  """Returns number of samples or steps, and whether to use steps count mode.

  If the first flattened element of `data` has a `shape`, its first dimension
  is returned as the sample count together with `False`. Otherwise
  `(steps_per_epoch, True)` is returned.
  """
  flat_inputs = nest.flatten(data)
  if hasattr(flat_inputs[0], 'shape'):
    return int(flat_inputs[0].shape[0]), False
  return steps_per_epoch, True


class GeneratorOrSequenceTrainingLoop(training_utils_v1.TrainingLoop):
  """Generator-like.

  Input is Python generator, or Sequence object.

  The difference between this class and `GeneratorLikeTrainingLoop` is that
  this class only handles inputs with x, y and sample_weight fused into one
  param.
  """

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          max_queue_size=10,
          workers=1,
          use_multiprocessing=False):
    """Fits `model` on the generator or `Sequence` passed as `x`."""
    model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
    training_utils_v1.check_generator_arguments(
        y, sample_weight, validation_split=validation_split)
    return fit_generator(
        model,
        x,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        class_weight=class_weight,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               max_queue_size=10,
               workers=1,
               use_multiprocessing=False):
    """Evaluates `model` on the generator or `Sequence` passed as `x`."""
    model._validate_or_infer_batch_size(batch_size, steps, x)
    training_utils_v1.check_generator_arguments(y, sample_weight)
    return evaluate_generator(
        model,
        x,
        steps=steps,
        verbose=verbose,
        callbacks=callbacks,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              max_queue_size=10,
              workers=1,
              use_multiprocessing=False):
    """Runs prediction with `model` on the generator or `Sequence` `x`."""
    model._validate_or_infer_batch_size(batch_size, steps, x)
    return predict_generator(
        model,
        x,
        steps=steps,
        verbose=verbose,
        callbacks=callbacks,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing)


class EagerDatasetOrIteratorTrainingLoop(training_utils_v1.TrainingLoop):
  """A non-distributed Dataset or iterator in eager execution."""

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    """Fits `model` on the dataset or iterator `x` (no enqueuer workers)."""
    model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
    # Make sure that y, sample_weights, validation_split are not passed.
    training_utils_v1.validate_dataset_input(x, y, sample_weight,
                                             validation_split)
    if (isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)) and
        shuffle):
      training_utils_v1.verify_dataset_shuffled(x)

    return fit_generator(
        model,
        x,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        class_weight=class_weight,
        workers=0,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    """Evaluates `model` on the dataset or iterator `x`."""
    model._validate_or_infer_batch_size(batch_size, steps, x)
    # Make sure that y, sample_weights, validation_split are not passed.
    training_utils_v1.validate_dataset_input(x, y, sample_weight)
    return evaluate_generator(
        model, x, steps=steps, verbose=verbose, workers=0, callbacks=callbacks)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    """Runs prediction with `model` on the dataset or iterator `x`."""
    model._validate_or_infer_batch_size(batch_size, steps, x)
    return predict_generator(
        model, x, steps=steps, verbose=verbose, workers=0, callbacks=callbacks)


class GeneratorLikeTrainingLoop(training_utils_v1.TrainingLoop):
  """TrainingLoop that handles inputs like a Python generator.

  This is the default handler for most of the input data types, including
  symbolic tensors or Numpy array-like, Datasets and iterators in graph mode
  (since they generate symbolic tensors). This loop is used to handle models
  with `run_eagerly` = True.
  """

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    """Standardizes `x`/`y`/`sample_weight` and fits `model` on them."""
    batch_size = model._validate_or_infer_batch_size(batch_size,
                                                     steps_per_epoch, x)
    x, y, sample_weights = model._standardize_user_data(
        x,
        y,
        sample_weight=sample_weight,
        class_weight=class_weight,
        batch_size=batch_size,
        check_steps=True,
        steps_name='steps_per_epoch',
        steps=steps_per_epoch,
        validation_split=validation_split,
        shuffle=shuffle)

    if validation_data:
      validation_data = model._prepare_validation_data(validation_data,
                                                       batch_size,
                                                       validation_steps)
    elif validation_split and 0. < validation_split < 1.:
      (x, y, sample_weights, val_x, val_y,
       val_sample_weights) = (
           training_utils_v1.split_training_and_validation_data(
               x, y, sample_weights, validation_split))
      validation_data = (val_x, val_y, val_sample_weights)
    else:
      if validation_steps:
        raise ValueError('`validation_steps` should not be specified if '
                         '`validation_data` is None.')

    return fit_generator(
        model, (x, y, sample_weights),
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        workers=0,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    """Standardizes `x`/`y`/`sample_weight` and evaluates `model` on them."""
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    x, y, sample_weights = model._standardize_user_data(
        x,
        y,
        sample_weight=sample_weight,
        batch_size=batch_size,
        check_steps=True,
        steps_name='steps',
        steps=steps)
    return evaluate_generator(
        model, (x, y, sample_weights),
        steps=steps,
        batch_size=batch_size,
        verbose=verbose,
        workers=0,
        callbacks=callbacks)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    """Standardizes `x` and runs prediction with `model` on it."""
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    x, _, _ = model._standardize_user_data(
        x, check_steps=True, steps_name='steps', steps=steps)
    return predict_generator(
        model,
        x,
        steps=steps,
        batch_size=batch_size,
        verbose=verbose,
        workers=0,
        callbacks=callbacks)