# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Part of the Keras training engine related to distributed training."""
# pylint: disable=protected-access

import numpy as np
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.distribute import distribute_coordinator_utils as dc
from tensorflow.python.keras.distribute import distributed_training_utils_v1 as dist_utils
from tensorflow.python.keras.engine import partial_batch_padding_handler as padding_util
from tensorflow.python.keras.engine import training_arrays_v1
from tensorflow.python.keras.engine import training_utils_v1
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import tf_logging as logging


def _per_replica_execution_function(model, mode):
  """Returns (inputs, outputs, updates op, session kwargs) for one replica."""
  exec_func = model._make_execution_function(mode)
  return (exec_func.inputs, exec_func.outputs, exec_func.updates_op,
          exec_func.session_kwargs)


def _build_model(strategy, model, mode, inputs, targets=None):
  """Builds the distributed model for `mode`, cloning if compiled that way."""
  if model._compile_distribution:
    dist_utils.clone_model_on_replicas(
        model, strategy, mode, inputs=inputs, targets=targets)
  else:
    dist_utils._build_distributed_network(model, strategy, mode, inputs,
                                          targets)


def _make_train_step_fn(model, mode, strategy, output_labels):
  """Create step fn.

  Args:
    model: a Keras Model instance.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
    strategy: a `tf.distribute.Strategy` instance.
    output_labels: the output labels for the step function.

  Returns:
    A step function to run by `tf.distribute.Strategy`.
  """

  def _step_fn(ctx, inputs):
    """A step fn that returns update ops."""
    if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
      inputs, targets = inputs
    else:
      targets = None

    # When the input feature is a dictionary of tensors, the dictionary is
    # flattened to an array and passed as a model input. This results in an
    # input mismatch when the model's input layer names are not sorted in
    # alphabetical order, since `nest.flatten()` sorts dictionary elements by
    # key. Therefore, transform the input tensors into an array ordered along
    # `model._feed_input_names`.
    if isinstance(inputs, dict):
      inputs = [inputs[input_name] for input_name in model._feed_input_names]
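      # Hypothetical illustration of the reordering above (names assumed):
      # with model._feed_input_names == ['age', 'name'] and
      # inputs == {'name': t_name, 'age': t_age}, the model is fed
      # [t_age, t_name].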

    _build_model(strategy, model, mode, inputs, targets)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = strategy.extended.call_for_each_replica(
         _per_replica_execution_function,
         args=(dist_utils.get_distributed_model(model, mode), mode))
    (all_inputs, all_outputs, all_updates,
     all_session_args) = dist_utils.unwrap_values(strategy, grouped_inputs,
                                                  grouped_outputs,
                                                  grouped_updates,
                                                  grouped_session_args)
    combined_fn = backend.function(
        all_inputs,
        all_outputs,
        updates=all_updates,
        name='distributed_' + str(mode) + '_function',
        **all_session_args)

    for label, output in zip(output_labels, combined_fn.outputs):
      if label == 'loss':
        reduce_op = ds_reduce_util.ReduceOp.SUM
      else:
        # We reduce all other metrics using mean for now. This is a temporary
        # workaround until new metrics are in place.
        reduce_op = ds_reduce_util.ReduceOp.MEAN
      ctx.set_last_step_output(label, output, reduce_op)

    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn
    # for now: feed_dict, session kwargs, run options, run_metadata. These
    # should be handled appropriately.
    return combined_fn.updates_op

  return _step_fn


def experimental_tpu_fit_loop(model,
                              dataset,
                              epochs=100,
                              verbose=1,
                              callbacks=None,
                              initial_epoch=0,
                              steps_per_epoch=None,
                              val_dataset=None,
                              validation_steps=None,
                              validation_freq=1):
  """Fit loop for training with TPU tf.distribute.Strategy.

  Args:
      model: Keras Model instance.
      dataset: Dataset that returns inputs and targets.
      epochs: Number of times to iterate over the data.
      verbose: Integer, Verbosity mode, 0, 1 or 2.
      callbacks: List of callbacks to be called during training.
      initial_epoch: Epoch at which to start training
          (useful for resuming a previous training run).
      steps_per_epoch: Total number of steps (batches of samples)
          before declaring one epoch finished and starting the
          next epoch. Ignored with the default value of `None`.
      val_dataset: Dataset for validation data.
      validation_steps: Number of steps to run validation for
          (only if doing validation from data tensors).
          Ignored with the default value of `None`.
      validation_freq: Only relevant if validation data is provided. Integer or
          `collections.abc.Container` instance (e.g. list, tuple, etc.). If an
          integer, specifies how many training epochs to run before a new
          validation run is performed, e.g. `validation_freq=2` runs
          validation every 2 epochs. If a Container, specifies the epochs on
          which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
          validation at the end of the 1st, 2nd, and 10th epochs.

  Returns:
      The model's `History` object.

  Raises:
      ValueError: in case of invalid arguments.
  """
  mode = ModeKeys.TRAIN

  current_strategy = model._distribution_strategy
  iteration_value = min(steps_per_epoch,
                        current_strategy.extended.steps_per_run)
  steps_per_run = backend.variable(
      value=iteration_value,
      dtype='int32',
      name='steps_per_run')

  # TODO(fchollet): add support for `steps_per_epoch=None` in TPU loops.
  iterator = dist_utils.get_iterator(dataset, current_strategy)

  scope = dist_utils.distributed_scope(
      strategy=current_strategy, learning_phase=1)
  scope.__enter__()

  out_labels = model.metrics_names or []

  step_fn = _make_train_step_fn(model, ModeKeys.TRAIN, current_strategy,
                                out_labels)

  # Add initial dummy values for loss and other metric tensors.
  initial_loop_values = {}
  initial_loop_values['loss'] = constant_op.constant(1e7)
  for m in model._get_training_eval_metrics():
    tensor = m.result()
    initial_loop_values[m.name] = array_ops.zeros(tensor.shape, tensor.dtype)

  ctx = current_strategy.extended.experimental_run_steps_on_iterator(
      step_fn, iterator, iterations=steps_per_run,
      initial_loop_values=initial_loop_values)
  train_op = ctx.run_op
  output_tensors = ctx.last_step_outputs

  do_validation = bool(validation_steps)

  if model._compile_distribution:
    dist_utils._copy_weights_to_distributed_model(model, mode)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      verbose=verbose,
      count_mode='steps',
      mode=mode)

  # Calculate the number of steps to run on the device each time, breaking
  # `steps_per_epoch` into chunks of at most `steps_per_run` steps.
  steps_to_run = ([current_strategy.extended.steps_per_run] *
                  (steps_per_epoch //
                   current_strategy.extended.steps_per_run))
  if steps_per_epoch % current_strategy.extended.steps_per_run:
    steps_to_run.append(
        steps_per_epoch % current_strategy.extended.steps_per_run)
  target_steps = len(steps_to_run)
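  # Worked example (values assumed for illustration): with steps_per_epoch=10
  # and steps_per_run=4, steps_to_run is [4, 4, 2] and target_steps is 3, so
  # each epoch issues three device loops covering all ten steps.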

  callbacks._call_begin_hook(mode)

  initial_epoch = model._maybe_load_initial_epoch_from_ckpt(initial_epoch, mode)

  for epoch in range(initial_epoch, epochs):
    dist_utils._reset_metrics(model)
    callbacks.on_epoch_begin(epoch)
    epoch_logs = {}
    step_index = 0
    prev_step_count = None
    current_step = 0
    while current_step < target_steps:
      step_count = steps_to_run[current_step]
      batch_logs = {'batch': step_index, 'size': 1, 'num_steps': step_count}
      callbacks._call_batch_hook(mode, 'begin', step_index, batch_logs)
      if prev_step_count is None or step_count != prev_step_count:
        backend.get_session().run(steps_per_run.assign(step_count))
        prev_step_count = step_count
      try:
        _, outputs = backend.batch_get_value([train_op, output_tensors])
      except errors.OutOfRangeError:
        logging.warning('Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your dataset '
                        'can generate at least `steps_per_epoch * epochs` '
                        'batches (in this case, %d batches).' %
                        (steps_per_epoch * epochs))
        break

      batch_logs.update(outputs)
      callbacks._call_batch_hook(mode, 'end', step_index, batch_logs)
      step_index = step_index + step_count
      current_step += 1

      if callbacks.model.stop_training:
        break

    if (do_validation and
        training_utils_v1.should_run_validation(validation_freq, epoch)):
      logging.info('Running validation at fit epoch: %s', epoch)

      if model._compile_distribution:
        # Since we create a new clone from the original model we need to copy
        # the weights back to the original model before we can run validation.
        dist_utils._copy_weights_to_original_model(model, ModeKeys.TRAIN)

      val_outs = experimental_tpu_test_loop(  # pylint: disable=undefined-variable
          model,
          val_dataset,
          steps=validation_steps,
          verbose=verbose,
          callbacks=callbacks)
      if not isinstance(val_outs, list):
        val_outs = [val_outs]
      # Same labels assumed.
      for label, val_out in zip(out_labels, val_outs):
        epoch_logs['val_' + label] = val_out

    callbacks.on_epoch_end(epoch, epoch_logs)
    if callbacks.model.stop_training:
      break
  model._successful_loop_finish = True
  callbacks._call_end_hook(mode)

  if model._compile_distribution:
    # Copy the weights back from the replicated model to the original model.
    dist_utils._copy_weights_to_original_model(model, ModeKeys.TRAIN)
  scope.__exit__(None, None, None)
  return model.history


def experimental_tpu_test_loop(model,
                               dataset,
                               verbose=0,
                               steps=None,
                               callbacks=None):
  """Test loop for evaluating with TPU tf.distribute.Strategy.

  Args:
      model: Keras Model instance.
      dataset: Dataset for input data.
      verbose: Integer, Verbosity mode 0 or 1.
      steps: Total number of steps (batches of samples)
          before declaring the evaluation round finished.
          Ignored with the default value of `None`.
      callbacks: List of callbacks to be called during evaluation.

  Returns:
      Scalar loss (if the model has a single output and no metrics)
      or list of scalars (if the model has multiple outputs
      and/or metrics). The attribute `model.metrics_names` will give you
      the display labels for the outputs.
  """
  mode = ModeKeys.TEST
  current_strategy = model._distribution_strategy
  iterator = dist_utils.get_iterator(dataset, current_strategy)

  scope = dist_utils.distributed_scope(
      strategy=current_strategy, learning_phase=0)
  scope.__enter__()

  out_labels = model.metrics_names

  def _test_step_fn(inputs):
    """A fn that returns output of single test step."""
    if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
      inputs, targets = inputs
    else:
      targets = None

    (distribution_strategy_context.get_replica_context().merge_call(
        _build_model, args=(model, mode, inputs, targets)))

    (_, outputs, updates, _) = _per_replica_execution_function(
        dist_utils.get_distributed_model(model, mode), mode)
    with ops.control_dependencies([updates]):
      return [array_ops.identity(out) for out in outputs]

  test_input_data = iterator.get_next()
  per_replica_outputs = current_strategy.run(
      _test_step_fn, args=(test_input_data,))
  output_tensors = {}
  for label, output in zip(out_labels, per_replica_outputs):
    if label == 'loss':
      reduce_op = ds_reduce_util.ReduceOp.SUM
    else:
      # We reduce all other metrics using mean for now. This is a temporary
      # workaround until new metrics are in place.
      reduce_op = ds_reduce_util.ReduceOp.MEAN
    output_tensors[label] = current_strategy.reduce(reduce_op, output,
                                                    axis=None)
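  # Note on the arithmetic above (values assumed for illustration): the
  # per-replica losses are summed across replicas here, accumulated over steps
  # below, and divided by `target_steps` at the end. E.g. replica losses
  # (0.2, 0.4) over 2 steps yield ((0.2 + 0.4) + (0.2 + 0.4)) / 2 = 0.6.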
  test_op = control_flow_ops.group(list(output_tensors.values()))

  if verbose >= 1:
    progbar = Progbar(target=steps)

  if model._compile_distribution:
    dist_utils._copy_weights_to_distributed_model(model, mode)

  dist_utils._reset_metrics(model)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=False,
      epochs=1,
      steps_per_epoch=steps,
      verbose=verbose,
      count_mode='steps',
      mode=ModeKeys.TEST)
  callbacks._call_begin_hook(mode)

  outs = [0.] * len(model.metrics_names)
  if steps is not None:
    target_steps = steps
  else:
    raise ValueError('Number of steps could not be inferred from the data, '
                     'please pass the steps argument.')

  current_step = 0
  while current_step < target_steps:
    batch_logs = {'batch': current_step, 'size': 1}
    callbacks._call_batch_hook(mode, 'begin', current_step, batch_logs)
    try:
      _, batch_outs = backend.batch_get_value([test_op, output_tensors])
    except errors.OutOfRangeError:
      warning_msg = (
          'Make sure that your dataset can generate at least '
          '`steps` batches (in this case, {} batches).'.format(steps))

      logging.warning('Your dataset iterator ran out of data; '
                      'interrupting evaluation. ' + warning_msg)
      target_steps = current_step
      break
    for i, label in enumerate(model.metrics_names):
      if i == 0:
        # Loss is a stateless metric, so accumulate it across steps here.
        outs[i] += batch_outs[label]
      else:
        # For all stateful metrics, aggregation is handled by mirrored
        # variables, so keep only the latest value.
        outs[i] = batch_outs[label]

    batch_logs = cbks.make_logs(model, batch_logs, outs, mode)
    callbacks._call_batch_hook(mode, 'end', current_step, batch_logs)
    if verbose == 1:
      progbar.update(current_step + 1)
    current_step += 1

  if verbose >= 1:
    # Progress bar finishes at the end.
    progbar.update(target_steps)
  callbacks._call_end_hook(mode)

  scope.__exit__(None, None, None)
  if outs:
    outs[0] /= target_steps

  if len(outs) == 1:
    return outs[0]
  return outs


def experimental_tpu_predict_loop(model,
                                  dataset,
                                  verbose=0,
                                  steps=None,
                                  callbacks=None):
  """Predict loop for predicting with TPU tf.distribute.Strategy.

  Args:
      model: Keras Model instance.
      dataset: Dataset for input data.
      verbose: Integer, Verbosity mode 0 or 1.
      steps: Total number of steps (batches of samples)
          before declaring `_predict_loop` finished.
          Ignored with the default value of `None`.
      callbacks: List of callbacks to be called during prediction.

  Returns:
      Array of predictions (if the model has a single output)
      or list of arrays of predictions
      (if the model has multiple outputs).
  """
  mode = ModeKeys.PREDICT
  dataset_fully_shaped = dist_utils.is_dataset_shape_fully_defined(dataset)
  padding_handler = None
  if not dataset_fully_shaped:
    # TODO(hongjunchoi): Investigate whether operations from
    # PartialBatchPaddingHandler are unnecessarily pruned out
    # during graph optimization.
    padding_handler = padding_util.PartialBatchPaddingHandler(
        model._feed_output_shapes)
    batch_size, _, prefetch_buffer = input_lib._get_dataset_attributes(dataset)
    padding_handler.padded_batch_size = batch_size
    padding_handler.padding_mask = dataset.reduce(padding_handler.padding_mask,
                                                  padding_handler.update_mask)

    dataset = dataset.map(padding_handler.pad_batch)
    dataset = dataset.unbatch()
    # At this point, it is guaranteed that the dataset does not
    # have partial batches. Thus, we set `drop_remainder=True` to
    # get static shape information about the elements in the dataset.
    dataset = dataset.batch(batch_size, drop_remainder=True)

    if prefetch_buffer is not None:
      dataset = dataset.prefetch(prefetch_buffer)
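    # Illustrative example (batch sizes assumed): with batch_size=8 and a
    # trailing partial batch of 3 elements, `pad_batch` pads that batch to 8,
    # `unbatch` + `batch(8, drop_remainder=True)` then yields uniformly shaped
    # batches, and `padding_handler.apply_mask` strips the 5 padded
    # predictions at the end of this loop.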

  current_strategy = model._distribution_strategy
  iterator = dist_utils.get_iterator(dataset, current_strategy)

  scope = dist_utils.distributed_scope(
      strategy=current_strategy, learning_phase=0)
  scope.__enter__()

  def _predict_step_fn(inputs):
    """A fn that returns output of single prediction step."""

    (distribution_strategy_context.get_replica_context().merge_call(
        _build_model, args=(model, mode, inputs)))

    (_, outputs, updates, _) = _per_replica_execution_function(
        dist_utils.get_distributed_model(model, mode), mode)

    with ops.control_dependencies([updates]):
      return [array_ops.identity(out) for out in outputs]

  # TODO(hongjunchoi): When a numpy array is passed as an input to `predict()`
  # use the numpy arrays directly to avoid accumulating unnecessary input
  # pipeline ops.
  predict_input_data = iterator.get_next()
  per_replica_outputs = current_strategy.run(
      _predict_step_fn, args=(predict_input_data,))
  output_tensors = dist_utils.flatten_per_replica_values(
      current_strategy, per_replica_outputs)

  if verbose >= 1:
    progbar = Progbar(target=steps)

  if model._compile_distribution:
    dist_utils._copy_weights_to_distributed_model(model, mode)

  dist_utils._reset_metrics(model)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=False,
      epochs=1,
      steps_per_epoch=steps,
      verbose=verbose,
      count_mode='steps',
      mode=mode)
  callbacks._call_begin_hook(mode)

  # Since we do not know how many samples we will see, we cannot pre-allocate
  # the returned Numpy arrays. Instead, we store one array per batch seen
  # and concatenate them upon returning.
  num_model_outputs = len(model.output_names)
  unconcatenated_outs = [[] for _ in range(num_model_outputs)]
  if steps is not None:
    target_steps = steps
  else:
    raise ValueError('Number of steps could not be inferred from the data, '
                     'please pass the steps argument.')

  current_step = 0
  while current_step < target_steps:
    batch_logs = {'batch': current_step, 'size': 1}
    callbacks._call_batch_hook(mode, 'begin', current_step, batch_logs)
    try:
      predict_ops = control_flow_ops.group(output_tensors)
      _, batch_outs = backend.batch_get_value([predict_ops, output_tensors])

    except errors.OutOfRangeError:
      warning_msg = (
          'Make sure that your dataset can generate at least '
          '`steps` batches (in this case, {} batches).'.format(steps))

      logging.warning('Your dataset iterator ran out of data; '
                      'interrupting prediction. ' + warning_msg)
      break

    # TODO(priyag): maybe need to unwrap the outputs first for MirroredStrategy.
    for i in range(num_model_outputs):
      output_start_index = i * current_strategy.num_replicas_in_sync
      output_end_index = (
          output_start_index + current_strategy.num_replicas_in_sync)
      single_model_output = batch_outs[output_start_index:output_end_index]
      unconcatenated_outs[i].extend(single_model_output)
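    # Layout illustration (replica count assumed): with 2 model outputs and 2
    # replicas, `batch_outs` is flattened as
    # [out0_replica0, out0_replica1, out1_replica0, out1_replica1], so output
    # `i` owns the slice [i * num_replicas : (i + 1) * num_replicas].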

    batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
    callbacks._call_batch_hook(mode, 'end', current_step, batch_logs)
    if verbose == 1:
      progbar.update(current_step + 1)
    current_step += 1

  if verbose >= 1:
    # Progress bar finishes at the end.
    progbar.update(current_step)

  callbacks._call_end_hook(mode)

  scope.__exit__(None, None, None)

  if len(unconcatenated_outs) == 1:
    prediction_result = np.concatenate(unconcatenated_outs[0], axis=0)
  else:
    prediction_result = [
        np.concatenate(out, axis=0) for out in unconcatenated_outs
    ]

  if padding_handler:
    prediction_result = padding_handler.apply_mask(prediction_result)

  return prediction_result


class DistributionSingleWorkerTrainingLoop(training_utils_v1.TrainingLoop):
  """Training loop for distribution strategy with a single worker."""

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    """Fit loop for Distribution Strategies."""
    dist_utils.validate_callbacks(input_callbacks=callbacks,
                                  optimizer=model.optimizer)
    dist_utils.validate_inputs(x, y)

    batch_size, steps_per_epoch = dist_utils.process_batch_and_step_size(
        model._distribution_strategy,
        x,
        batch_size,
        steps_per_epoch,
        ModeKeys.TRAIN,
        validation_split=validation_split)
    batch_size = model._validate_or_infer_batch_size(
        batch_size, steps_per_epoch, x)
    dataset = model._distribution_standardize_user_data(
        x, y,
        sample_weight=sample_weight,
        class_weight=class_weight,
        batch_size=batch_size,
        validation_split=validation_split,
        shuffle=shuffle,
        epochs=epochs)
    if not dist_utils.is_distributing_by_cloning(model):
      with model._distribution_strategy.scope():
        (dataset, _, _) = model._standardize_user_data(
            dataset,
            sample_weight=sample_weight,
            class_weight=class_weight,
            batch_size=batch_size,
            validation_split=validation_split,
            shuffle=shuffle)

    val_dataset = None
    if validation_data:
      val_x, val_y, val_sample_weights = (
          training_utils_v1.unpack_validation_data(validation_data))
      dist_utils.validate_inputs(val_x, val_y)
      _, validation_steps = dist_utils.process_batch_and_step_size(
          model._distribution_strategy, val_x, batch_size, validation_steps,
          ModeKeys.TEST)

      val_dataset = model._distribution_standardize_user_data(
          val_x, val_y,
          sample_weight=val_sample_weights,
          class_weight=None,
          batch_size=batch_size,
          validation_split=validation_split,
          shuffle=shuffle,
          allow_partial_batch=True)
    elif validation_split:
      raise ValueError('validation_split argument is not supported with '
                       'distribution strategies.')

    if backend.is_tpu_strategy(model._distribution_strategy):
      steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
          model, dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch')
      if steps_per_epoch is None:
        raise ValueError('Number of steps could not be inferred from the data, '
                         'please pass the steps_per_epoch argument.')

      if not context.executing_eagerly():
        # Run TPU training in a custom loop in graph mode.
        return experimental_tpu_fit_loop(
            model,
            dataset,
            epochs=epochs,
            verbose=verbose,
            callbacks=callbacks,
            val_dataset=val_dataset,
            initial_epoch=initial_epoch,
            steps_per_epoch=steps_per_epoch,
            validation_steps=validation_steps,
            validation_freq=validation_freq)

    return training_arrays_v1.fit_loop(
        model,
        dataset,
        batch_size=batch_size,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        val_inputs=val_dataset,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    """Evaluate loop for Distribution Strategies."""
    dist_utils.validate_inputs(x, y)
    batch_size, steps = dist_utils.process_batch_and_step_size(
        model._distribution_strategy, x, batch_size, steps, ModeKeys.TEST)
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    dataset = model._distribution_standardize_user_data(
        x, y,
        sample_weight=sample_weight,
        batch_size=batch_size,
        allow_partial_batch=True)

    if backend.is_tpu_strategy(model._distribution_strategy):
      steps = training_utils_v1.infer_steps_for_dataset(
          model, dataset, steps, steps_name='steps')
      if steps is None:
        raise ValueError('Number of steps could not be inferred from the data, '
                         'please pass the steps argument.')

      if not context.executing_eagerly():
        # Run TPU evaluation in a custom loop in graph mode.
        return experimental_tpu_test_loop(
            model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)

    return training_arrays_v1.test_loop(
        model,
        inputs=dataset,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    """Predict loop for Distribution Strategies."""
    dist_utils.validate_inputs(x=x, y=None)
    batch_size, steps = dist_utils.process_batch_and_step_size(
        model._distribution_strategy, x, batch_size, steps, ModeKeys.PREDICT)
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    dataset = model._distribution_standardize_user_data(
        x,
        batch_size=batch_size,
        allow_partial_batch=True)
    if backend.is_tpu_strategy(model._distribution_strategy):
      steps = training_utils_v1.infer_steps_for_dataset(
          model, dataset, steps, steps_name='steps')
      if steps is None:
        raise ValueError('Number of steps could not be inferred from the data, '
                         'please pass the steps argument.')
      if not context.executing_eagerly():
        return experimental_tpu_predict_loop(
            model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)
    return training_arrays_v1.predict_loop(
        model,
        dataset,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks)


def _train_with_multi_worker(method):
  """Decorator that handles multi-worker training with distribution strategy."""

  def wrapper(model, **kwargs):
    def _worker_fn(_):
      callbacks = kwargs.pop('callbacks', None)
      filtered_callbacks = dist_utils.filter_distributed_callbacks(
          callbacks, model)
      kwargs['callbacks'] = filtered_callbacks
      return method(model, **kwargs)

    return dc.run_distribute_coordinator(
        _worker_fn,
        model._distribution_strategy)

  return wrapper
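
# Hypothetical usage of the decorator above (names assumed): wrapping a
# single-worker method yields a multi-worker variant, e.g.
#   multi_worker_fit = _train_with_multi_worker(single_worker_loop.fit)
#   multi_worker_fit(model, x=dataset, epochs=2)
# which is exactly how DistributionMultiWorkerTrainingLoop below delegates.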


class DistributionMultiWorkerTrainingLoop(training_utils_v1.TrainingLoop):
  """Training loop for distribution strategy with multiple workers."""

  def __init__(self, single_worker_loop):
    self._single_worker_loop = single_worker_loop

  def fit(self, *args, **kwargs):
    return _train_with_multi_worker(self._single_worker_loop.fit)(
        *args, **kwargs)

  def evaluate(self, *args, **kwargs):
    return _train_with_multi_worker(self._single_worker_loop.evaluate)(
        *args, **kwargs)

  def predict(self, *args, **kwargs):
    # Currently predict is still using the single worker implementation.
    return self._single_worker_loop.predict(*args, **kwargs)
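

# The sketch below is illustrative only and not part of the original module:
# it shows how the loops above would be driven directly, assuming `model` is a
# compiled V1 Keras model built under a `tf.distribute` strategy scope and
# `dataset` is a `tf.data.Dataset` of (inputs, targets). In practice,
# `Model.fit`/`evaluate`/`predict` select and invoke these loop classes
# internally.
def _example_drive_single_worker_loop(model, dataset):
  """Hedged usage sketch; never called by the library code above."""
  loop = DistributionSingleWorkerTrainingLoop()
  history = loop.fit(model, x=dataset, epochs=2, steps_per_epoch=100)
  results = loop.evaluate(model, x=dataset, steps=10)
  return history, results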
795