# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Part of the Keras training engine related to distributed training."""
# pylint: disable=protected-access

import numpy as np

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.distribute import distribute_coordinator_utils as dc
from tensorflow.python.keras.distribute import distributed_training_utils_v1 as dist_utils
from tensorflow.python.keras.engine import partial_batch_padding_handler as padding_util
from tensorflow.python.keras.engine import training_arrays_v1
from tensorflow.python.keras.engine import training_utils_v1
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import tf_logging as logging


def _per_replica_execution_function(model, mode):
  exec_func = model._make_execution_function(mode)
  return (exec_func.inputs, exec_func.outputs, exec_func.updates_op,
          exec_func.session_kwargs)


def _build_model(strategy, model, mode, inputs, targets=None):
  if model._compile_distribution:
    dist_utils.clone_model_on_replicas(
        model, strategy, mode, inputs=inputs, targets=targets)
  else:
    dist_utils._build_distributed_network(model, strategy, mode, inputs,
                                          targets)


def _make_train_step_fn(model, mode, strategy, output_labels):
  """Create step fn.

  Args:
    model: a Keras Model instance.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
    strategy: a `tf.distribute.Strategy` instance.
    output_labels: the output labels for the step function.

  Returns:
    A step function to run by `tf.distribute.Strategy`.
  """

  def _step_fn(ctx, inputs):
    """A step fn that returns update ops."""
    if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
      inputs, targets = inputs
    else:
      targets = None

    # When the input feature is a dictionary of tensors, the dictionary is
    # flattened to an array and passed as a model input. This results in an
    # input mismatch when the model input layer names are not sorted in
    # alphabetical order, as `nest.flatten()` sorts dictionary elements by
    # keys. As such, transform the input tensors into an array and order it
    # along `model._feed_input_names`.
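    # For example (hypothetical input names, for illustration only): a model
    # whose `_feed_input_names` is ['age', 'name'] turns the dict input
    # {'name': name_t, 'age': age_t} into the list [age_t, name_t].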
    if isinstance(inputs, dict):
      inputs = [inputs[input_name] for input_name in model._feed_input_names]

    _build_model(strategy, model, mode, inputs, targets)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = strategy.extended.call_for_each_replica(
         _per_replica_execution_function,
         args=(dist_utils.get_distributed_model(model, mode), mode))
    (all_inputs, all_outputs, all_updates,
     all_session_args) = dist_utils.unwrap_values(strategy, grouped_inputs,
                                                  grouped_outputs,
                                                  grouped_updates,
                                                  grouped_session_args)
    combined_fn = backend.function(
        all_inputs,
        all_outputs,
        updates=all_updates,
        name='distributed_' + str(mode) + '_function',
        **all_session_args)

    for label, output in zip(output_labels, combined_fn.outputs):
      if label == 'loss':
        reduce_op = ds_reduce_util.ReduceOp.SUM
      else:
        # We reduce all other metrics using mean for now. This is a temporary
        # workaround until new metrics are in place.
        reduce_op = ds_reduce_util.ReduceOp.MEAN
      ctx.set_last_step_output(label, output, reduce_op)

    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
    # feed_dict, session kwargs, run options, run_metadata for now. These
    # should be handled appropriately.
    return combined_fn.updates_op

  return _step_fn


def experimental_tpu_fit_loop(model,
                              dataset,
                              epochs=100,
                              verbose=1,
                              callbacks=None,
                              initial_epoch=0,
                              steps_per_epoch=None,
                              val_dataset=None,
                              validation_steps=None,
                              validation_freq=1):
  """Fit loop for training with TPU tf.distribute.Strategy.

  Args:
    model: Keras Model instance.
    dataset: Dataset that returns inputs and targets.
    epochs: Number of times to iterate over the data.
    verbose: Integer, Verbosity mode, 0, 1 or 2.
    callbacks: List of callbacks to be called during training.
    initial_epoch: Epoch at which to start training
      (useful for resuming a previous training run).
    steps_per_epoch: Total number of steps (batches of samples)
      before declaring one epoch finished and starting the
      next epoch. Ignored with the default value of `None`.
    val_dataset: Dataset for validation data.
    validation_steps: Number of steps to run validation for
      (only if doing validation from data tensors).
      Ignored with the default value of `None`.
    validation_freq: Only relevant if validation data is provided. Integer or
      `collections.abc.Container` instance (e.g. list, tuple, etc.). If an
      integer, specifies how many training epochs to run before a new
      validation run is performed, e.g. `validation_freq=2` runs
      validation every 2 epochs. If a Container, specifies the epochs on
      which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
      validation at the end of the 1st, 2nd, and 10th epochs.

  Returns:
    The `History` object of the trained model.

  Raises:
    ValueError: in case of invalid arguments.
  """
  mode = ModeKeys.TRAIN

  current_strategy = model._distribution_strategy
  iteration_value = min(steps_per_epoch,
                        current_strategy.extended.steps_per_run)
  steps_per_run = backend.variable(
      value=iteration_value,
      dtype='int32',
      name='steps_per_run')

  # TODO(fchollet): add support for `steps_per_epoch=None` in TPU loops.
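  # Note: `steps_per_run` above is a host-side variable rather than a Python
  # constant so that the number of device iterations per `Session.run` call
  # can be re-assigned between runs; the per-epoch schedule is computed in
  # `steps_to_run` below. For example, with steps_per_epoch=10 and
  # steps_per_run=4, each epoch executes runs of [4, 4, 2] steps.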
  iterator = dist_utils.get_iterator(dataset, current_strategy)

  scope = dist_utils.distributed_scope(
      strategy=current_strategy, learning_phase=1)
  scope.__enter__()

  out_labels = model.metrics_names or []

  step_fn = _make_train_step_fn(model, ModeKeys.TRAIN, current_strategy,
                                out_labels)

  # Add initial dummy values for loss and other metric tensors.
  initial_loop_values = {}
  initial_loop_values['loss'] = constant_op.constant(1e7)
  for m in model._get_training_eval_metrics():
    tensor = m.result()
    initial_loop_values[m.name] = array_ops.zeros(tensor.shape, tensor.dtype)

  ctx = current_strategy.extended.experimental_run_steps_on_iterator(
      step_fn, iterator, iterations=steps_per_run,
      initial_loop_values=initial_loop_values)
  train_op = ctx.run_op
  output_tensors = ctx.last_step_outputs

  do_validation = bool(validation_steps)

  if model._compile_distribution:
    dist_utils._copy_weights_to_distributed_model(model, mode)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      verbose=verbose,
      count_mode='steps',
      mode=mode)

  # Calculate the steps each time on the device.
  steps_to_run = ([current_strategy.extended.steps_per_run] *
                  (steps_per_epoch //
                   current_strategy.extended.steps_per_run))
  if steps_per_epoch % current_strategy.extended.steps_per_run:
    steps_to_run.append(
        steps_per_epoch % current_strategy.extended.steps_per_run)
  target_steps = len(steps_to_run)

  callbacks._call_begin_hook(mode)

  initial_epoch = model._maybe_load_initial_epoch_from_ckpt(initial_epoch, mode)

  for epoch in range(initial_epoch, epochs):
    dist_utils._reset_metrics(model)
    callbacks.on_epoch_begin(epoch)
    epoch_logs = {}
    step_index = 0
    prev_step_count = None
    current_step = 0
    while current_step < target_steps:
      step_count = steps_to_run[current_step]
      batch_logs = {'batch': step_index, 'size': 1, 'num_steps': step_count}
      callbacks._call_batch_hook(mode, 'begin', step_index, batch_logs)
      if prev_step_count is None or step_count != prev_step_count:
        backend.get_session().run(steps_per_run.assign(step_count))
        prev_step_count = step_count
      try:
        _, outputs = backend.batch_get_value([train_op, output_tensors])
      except errors.OutOfRangeError:
        logging.warning('Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your dataset '
                        'can generate at least `steps_per_epoch * epochs` '
                        'batches (in this case, %d batches).' %
                        (steps_per_epoch * epochs))
        break

      batch_logs.update(outputs)
      callbacks._call_batch_hook(mode, 'end', step_index, batch_logs)
      step_index = step_index + step_count
      current_step += 1

      if callbacks.model.stop_training:
        break

    if (do_validation and
        training_utils_v1.should_run_validation(validation_freq, epoch)):
      logging.info('Running validation at fit epoch: %s', epoch)

      if model._compile_distribution:
        # Since we create a new clone from the original model we need to copy
        # the weights back to the original model before we can run validation.
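        # (`experimental_tpu_test_loop` below then copies the refreshed
        # original weights into its own test-mode clone before evaluating.)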
        dist_utils._copy_weights_to_original_model(model, ModeKeys.TRAIN)

      val_outs = experimental_tpu_test_loop(  # pylint: disable=undefined-variable
          model,
          val_dataset,
          steps=validation_steps,
          verbose=verbose,
          callbacks=callbacks)
      if not isinstance(val_outs, list):
        val_outs = [val_outs]
      # Same labels assumed.
      for label, val_out in zip(out_labels, val_outs):
        epoch_logs['val_' + label] = val_out

    callbacks.on_epoch_end(epoch, epoch_logs)
    if callbacks.model.stop_training:
      break
  model._successful_loop_finish = True
  callbacks._call_end_hook(mode)

  if model._compile_distribution:
    # Copy the weights back from the replicated model to the original model.
    dist_utils._copy_weights_to_original_model(model, ModeKeys.TRAIN)
  scope.__exit__(None, None, None)
  return model.history


def experimental_tpu_test_loop(model,
                               dataset,
                               verbose=0,
                               steps=None,
                               callbacks=None):
  """Test loop for evaluating with TPU tf.distribute.Strategy.

  Args:
    model: Keras Model instance.
    dataset: Dataset for input data.
    verbose: Integer, Verbosity mode 0 or 1.
    steps: Total number of steps (batches of samples)
      before declaring predictions finished.
      Ignored with the default value of `None`.
    callbacks: List of callbacks to be called during evaluation.

  Returns:
    Scalar loss (if the model has a single output and no metrics)
    or list of scalars (if the model has multiple outputs
    and/or metrics). The attribute `model.metrics_names` will give you
    the display labels for the outputs.
  """
  mode = ModeKeys.TEST
  current_strategy = model._distribution_strategy
  iterator = dist_utils.get_iterator(dataset, current_strategy)

  scope = dist_utils.distributed_scope(
      strategy=current_strategy, learning_phase=0)
  scope.__enter__()

  out_labels = model.metrics_names

  def _test_step_fn(inputs):
    """A fn that returns output of single test step."""
    if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
      inputs, targets = inputs
    else:
      targets = None

    (distribution_strategy_context.get_replica_context().merge_call(
        _build_model, args=(model, mode, inputs, targets)))

    (_, outputs, updates, _) = _per_replica_execution_function(
        dist_utils.get_distributed_model(model, mode), mode)
    with ops.control_dependencies([updates]):
      return [array_ops.identity(out) for out in outputs]

  test_input_data = iterator.get_next()
  per_replica_outputs = current_strategy.run(
      _test_step_fn, args=(test_input_data,))
  output_tensors = {}
  for label, output in zip(out_labels, per_replica_outputs):
    if label == 'loss':
      reduce_op = ds_reduce_util.ReduceOp.SUM
    else:
      # We reduce all other metrics using mean for now. This is a temporary
      # workaround until new metrics are in place.
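      # (Loss, by contrast, uses SUM, matching the reduction chosen in
      # `_make_train_step_fn`; per-replica losses are assumed to be scaled
      # such that summing across replicas yields the global-batch loss.)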
      reduce_op = ds_reduce_util.ReduceOp.MEAN
    output_tensors[label] = current_strategy.reduce(reduce_op, output,
                                                    axis=None)
  test_op = control_flow_ops.group(list(output_tensors.values()))

  if verbose >= 1:
    progbar = Progbar(target=steps)

  if model._compile_distribution:
    dist_utils._copy_weights_to_distributed_model(model, mode)

  dist_utils._reset_metrics(model)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=False,
      epochs=1,
      steps_per_epoch=steps,
      verbose=verbose,
      count_mode='steps',
      mode=ModeKeys.TEST)
  callbacks._call_begin_hook(mode)

  outs = [0.] * len(model.metrics_names)
  if steps is not None:
    target_steps = steps
  else:
    raise ValueError('Number of steps could not be inferred from the data, '
                     'please pass the steps argument.')

  current_step = 0
  while current_step < target_steps:
    batch_logs = {'batch': current_step, 'size': 1}
    callbacks._call_batch_hook(mode, 'begin', current_step, batch_logs)
    try:
      _, batch_outs = backend.batch_get_value([test_op, output_tensors])
    except errors.OutOfRangeError:
      warning_msg = (
          'Make sure that your dataset can generate at least '
          '`steps` batches (in this case, {} batches).'.format(steps))

      logging.warning('Your dataset iterator ran out of data; '
                      'interrupting evaluation. ' + warning_msg)
      target_steps = current_step
      break
    for i, label in enumerate(model.metrics_names):
      if i == 0:
        # Loss is a stateless metric, so accumulate it across batches here.
        outs[i] += batch_outs[label]
      else:
        # For stateful metrics, the aggregation is handled by mirrored vars.
        outs[i] = batch_outs[label]

    batch_logs = cbks.make_logs(model, batch_logs, outs, mode)
    callbacks._call_batch_hook(mode, 'end', current_step, batch_logs)
    if verbose == 1:
      progbar.update(current_step + 1)
    current_step += 1

  if verbose >= 1:
    # Progress bar finishes at the end.
    progbar.update(target_steps)
  callbacks._call_end_hook(mode)

  scope.__exit__(None, None, None)
  if len(outs) > 0:
    outs[0] /= target_steps

  if len(outs) == 1:
    return outs[0]
  return outs


def experimental_tpu_predict_loop(model,
                                  dataset,
                                  verbose=0,
                                  steps=None,
                                  callbacks=None):
  """Predict loop for predicting with TPU tf.distribute.Strategy.

  Args:
    model: Keras Model instance.
    dataset: Dataset for input data.
    verbose: Integer, Verbosity mode 0 or 1.
    steps: Total number of steps (batches of samples)
      before declaring `_predict_loop` finished.
      Ignored with the default value of `None`.
    callbacks: List of callbacks to be called during prediction.

  Returns:
    Array of predictions (if the model has a single output)
    or list of arrays of predictions
    (if the model has multiple outputs).
  """
  mode = ModeKeys.PREDICT
  dataset_fully_shaped = dist_utils.is_dataset_shape_fully_defined(dataset)
  padding_handler = None
  if not dataset_fully_shaped:
    # TODO(hongjunchoi): Investigate whether operations from
    # PartialBatchPaddingHandler are unnecessarily pruned out
    # during graph optimization.
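    # The handler below pads the final partial batch up to the static batch
    # size and records a boolean mask over rows, so that the padded rows can
    # be stripped from the predictions by `apply_mask` at the end of this
    # function.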
    padding_handler = padding_util.PartialBatchPaddingHandler(
        model._feed_output_shapes)
    batch_size, _, prefetch_buffer = input_lib._get_dataset_attributes(dataset)
    padding_handler.padded_batch_size = batch_size
    padding_handler.padding_mask = dataset.reduce(padding_handler.padding_mask,
                                                  padding_handler.update_mask)

    dataset = dataset.map(padding_handler.pad_batch)
    dataset = dataset.unbatch()
    # At this point, it is guaranteed that the dataset does not
    # have partial batches. Thus, we set `drop_remainder=True` to
    # get static shape information about the elements in the dataset.
    dataset = dataset.batch(batch_size, drop_remainder=True)

    if prefetch_buffer is not None:
      dataset = dataset.prefetch(prefetch_buffer)

  current_strategy = model._distribution_strategy
  iterator = dist_utils.get_iterator(dataset, current_strategy)

  scope = dist_utils.distributed_scope(
      strategy=current_strategy, learning_phase=0)
  scope.__enter__()

  def _predict_step_fn(inputs):
    """A fn that returns output of single prediction step."""

    (distribution_strategy_context.get_replica_context().merge_call(
        _build_model, args=(model, mode, inputs)))

    (_, outputs, updates, _) = _per_replica_execution_function(
        dist_utils.get_distributed_model(model, mode), mode)

    with ops.control_dependencies([updates]):
      return [array_ops.identity(out) for out in outputs]

  # TODO(hongjunchoi): When a numpy array is passed as an input to `predict()`,
  # use the numpy arrays directly to avoid accumulating unnecessary input
  # pipeline ops.
  predict_input_data = iterator.get_next()
  per_replica_outputs = current_strategy.run(
      _predict_step_fn, args=(predict_input_data,))
  output_tensors = dist_utils.flatten_per_replica_values(
      current_strategy, per_replica_outputs)

  if verbose >= 1:
    progbar = Progbar(target=steps)

  if model._compile_distribution:
    dist_utils._copy_weights_to_distributed_model(model, mode)

  dist_utils._reset_metrics(model)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=False,
      epochs=1,
      steps_per_epoch=steps,
      verbose=verbose,
      count_mode='steps',
      mode=mode)
  callbacks._call_begin_hook(mode)

  # Since we do not know how many samples we will see, we cannot pre-allocate
  # the returned Numpy arrays. Instead, we store one array per batch seen
  # and concatenate them upon returning.
  num_model_outputs = len(model.output_names)
  unconcatenated_outs = [[] for _ in range(num_model_outputs)]
  if steps is not None:
    target_steps = steps
  else:
    raise ValueError('Number of steps could not be inferred from the data, '
                     'please pass the steps argument.')

  current_step = 0
  while current_step < target_steps:
    batch_logs = {'batch': current_step, 'size': 1}
    callbacks._call_batch_hook(mode, 'begin', current_step, batch_logs)
    try:
      predict_ops = control_flow_ops.group(output_tensors)
      _, batch_outs = backend.batch_get_value([predict_ops, output_tensors])

    except errors.OutOfRangeError:
      warning_msg = (
          'Make sure that your dataset can generate at least '
          '`steps` batches (in this case, {} batches).'.format(steps))

      logging.warning('Your dataset iterator ran out of data; '
                      'interrupting prediction. ' + warning_msg)
      break

    # TODO(priyag): maybe need to unwrap the outputs first for MirroredStrategy.
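    # `batch_outs` holds `num_replicas_in_sync` consecutive entries per model
    # output. For example, with 2 outputs and 2 replicas, the flattened order
    # is [out0_r0, out0_r1, out1_r0, out1_r1], which the slicing below
    # regroups per output.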
    for i in range(num_model_outputs):
      output_start_index = i * current_strategy.num_replicas_in_sync
      output_end_index = (
          output_start_index + current_strategy.num_replicas_in_sync)
      single_model_output = batch_outs[output_start_index:output_end_index]
      unconcatenated_outs[i].extend(single_model_output)

    batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
    callbacks._call_batch_hook(mode, 'end', current_step, batch_logs)
    if verbose == 1:
      progbar.update(current_step + 1)
    current_step += 1

  if verbose >= 1:
    # Progress bar finishes at the end.
    progbar.update(current_step)

  callbacks._call_end_hook(mode)

  scope.__exit__(None, None, None)

  if len(unconcatenated_outs) == 1:
    prediction_result = np.concatenate(unconcatenated_outs[0], axis=0)
  else:
    prediction_result = [
        np.concatenate(out, axis=0) for out in unconcatenated_outs
    ]

  if padding_handler:
    prediction_result = padding_handler.apply_mask(prediction_result)

  return prediction_result


class DistributionSingleWorkerTrainingLoop(training_utils_v1.TrainingLoop):
  """Training loop for distribution strategy with a single worker."""

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    """Fit loop for Distribution Strategies."""
    dist_utils.validate_callbacks(input_callbacks=callbacks,
                                  optimizer=model.optimizer)
    dist_utils.validate_inputs(x, y)

    batch_size, steps_per_epoch = dist_utils.process_batch_and_step_size(
        model._distribution_strategy,
        x,
        batch_size,
        steps_per_epoch,
        ModeKeys.TRAIN,
        validation_split=validation_split)
    batch_size = model._validate_or_infer_batch_size(
        batch_size, steps_per_epoch, x)
    dataset = model._distribution_standardize_user_data(
        x, y,
        sample_weight=sample_weight,
        class_weight=class_weight,
        batch_size=batch_size,
        validation_split=validation_split,
        shuffle=shuffle,
        epochs=epochs)
    if not dist_utils.is_distributing_by_cloning(model):
      with model._distribution_strategy.scope():
        (dataset, _, _) = model._standardize_user_data(
            dataset,
            sample_weight=sample_weight,
            class_weight=class_weight,
            batch_size=batch_size,
            validation_split=validation_split,
            shuffle=shuffle)

    val_dataset = None
    if validation_data:
      val_x, val_y, val_sample_weights = (
          training_utils_v1.unpack_validation_data(validation_data))
      dist_utils.validate_inputs(val_x, val_y)
      _, validation_steps = dist_utils.process_batch_and_step_size(
          model._distribution_strategy, val_x, batch_size, validation_steps,
          ModeKeys.TEST)

      val_dataset = model._distribution_standardize_user_data(
          val_x, val_y,
          sample_weight=val_sample_weights,
          class_weight=None,
          batch_size=batch_size,
          validation_split=validation_split,
          shuffle=shuffle,
          allow_partial_batch=True)
    elif validation_split:
      raise ValueError('validation_split argument is not supported with '
                       'distribution strategies.')

    if backend.is_tpu_strategy(model._distribution_strategy):
      steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
          model, dataset, steps_per_epoch, epochs,
          steps_name='steps_per_epoch')
      if steps_per_epoch is None:
        raise ValueError('Number of steps could not be inferred from the data, '
                         'please pass the steps_per_epoch argument.')

      if not context.executing_eagerly():
        # Run TPU training in a custom loop in graph mode.
        return experimental_tpu_fit_loop(
            model,
            dataset,
            epochs=epochs,
            verbose=verbose,
            callbacks=callbacks,
            val_dataset=val_dataset,
            initial_epoch=initial_epoch,
            steps_per_epoch=steps_per_epoch,
            validation_steps=validation_steps,
            validation_freq=validation_freq)

    return training_arrays_v1.fit_loop(
        model,
        dataset,
        batch_size=batch_size,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        val_inputs=val_dataset,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    """Evaluate loop for Distribution Strategies."""
    dist_utils.validate_inputs(x, y)
    batch_size, steps = dist_utils.process_batch_and_step_size(
        model._distribution_strategy, x, batch_size, steps, ModeKeys.TEST)
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    dataset = model._distribution_standardize_user_data(
        x, y,
        sample_weight=sample_weight,
        batch_size=batch_size,
        allow_partial_batch=True)

    if backend.is_tpu_strategy(model._distribution_strategy):
      steps = training_utils_v1.infer_steps_for_dataset(
          model, dataset, steps, steps_name='steps')
      if steps is None:
        raise ValueError('Number of steps could not be inferred from the data, '
                         'please pass the steps argument.')

      if not context.executing_eagerly():
        # Run TPU evaluation in a custom loop in graph mode.
        return experimental_tpu_test_loop(
            model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)

    return training_arrays_v1.test_loop(
        model,
        inputs=dataset,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    """Predict loop for Distribution Strategies."""
    dist_utils.validate_inputs(x=x, y=None)
    batch_size, steps = dist_utils.process_batch_and_step_size(
        model._distribution_strategy, x, batch_size, steps, ModeKeys.PREDICT)
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    dataset = model._distribution_standardize_user_data(
        x,
        batch_size=batch_size,
        allow_partial_batch=True)
    if backend.is_tpu_strategy(model._distribution_strategy):
      steps = training_utils_v1.infer_steps_for_dataset(
          model, dataset, steps, steps_name='steps')
      if steps is None:
        raise ValueError('Number of steps could not be inferred from the data, '
                         'please pass the steps argument.')
      if not context.executing_eagerly():
        return experimental_tpu_predict_loop(
            model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)
    return training_arrays_v1.predict_loop(
        model,
        dataset,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks)


def _train_with_multi_worker(method):
  """Decorator that handles multi-worker training with distribution strategy."""

  def wrapper(model, **kwargs):
    def _worker_fn(_):
      callbacks = kwargs.pop('callbacks', None)
      filtered_callbacks = dist_utils.filter_distributed_callbacks(
          callbacks, model)
      kwargs['callbacks'] = filtered_callbacks
      return method(model, **kwargs)

    return dc.run_distribute_coordinator(
        _worker_fn,
        model._distribution_strategy)

  return wrapper


class DistributionMultiWorkerTrainingLoop(training_utils_v1.TrainingLoop):
  """Training loop for distribution strategy with multiple workers."""

  def __init__(self, single_worker_loop):
    self._single_worker_loop = single_worker_loop

  def fit(self, *args, **kwargs):
    return _train_with_multi_worker(self._single_worker_loop.fit)(
        *args, **kwargs)

  def evaluate(self, *args, **kwargs):
    return _train_with_multi_worker(self._single_worker_loop.evaluate)(
        *args, **kwargs)

  def predict(self, *args, **kwargs):
    # Currently predict is still using the single-worker implementation.
    return self._single_worker_loop.predict(*args, **kwargs)
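
# Usage sketch (illustrative only; TPU setup details are environment-specific
# and these loop classes are selected internally by the v1 Keras engine, not
# by user code):
#
#   strategy = ...  # e.g. a TPUStrategy built from a TPUClusterResolver.
#   with strategy.scope():
#     model = keras.Sequential([keras.layers.Dense(1)])
#     model.compile('sgd', 'mse')
#   model.fit(dataset, epochs=2, steps_per_epoch=10)
#
# In graph mode under a TPU strategy, the `fit` call above dispatches to
# `experimental_tpu_fit_loop`; `evaluate` and `predict` dispatch to the
# corresponding `experimental_tpu_*_loop` functions.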