1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities related to distributed training."""
16# pylint:disable=protected-access
17
18import functools
19
20import numpy as np
21
22from tensorflow.python.data.ops import dataset_ops
23from tensorflow.python.data.ops import iterator_ops
24from tensorflow.python.distribute import reduce_util
25from tensorflow.python.eager import context
26from tensorflow.python.eager import def_function
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import sparse_tensor
30from tensorflow.python.framework import tensor_util
31from tensorflow.python.keras import backend
32from tensorflow.python.keras import callbacks
33from tensorflow.python.keras import metrics as metrics_module
34from tensorflow.python.keras import optimizers
35from tensorflow.python.keras.distribute import distribute_coordinator_utils as dc
36from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils
37from tensorflow.python.keras.engine import training_utils_v1
38from tensorflow.python.keras.optimizer_v2 import optimizer_v2
39from tensorflow.python.keras.utils import tf_contextlib
40from tensorflow.python.keras.utils.mode_keys import ModeKeys
41from tensorflow.python.ops import array_ops
42from tensorflow.python.ops import control_flow_ops
43from tensorflow.python.ops import math_ops
44from tensorflow.python.ops import sparse_ops
45from tensorflow.python.ops import variables
46from tensorflow.python.ops.ragged import ragged_tensor
47from tensorflow.python.platform import tf_logging as logging
48from tensorflow.python.util import nest
49
50
51def set_weights(distribution_strategy, dist_model, weights):
52  """Sets the weights of the replicated models.
53
54  The weights of the replicated models are set to the weights of the original
55  model. The weights of the replicated model are Mirrored variables and hence
56  we need to use the `update` call within a DistributionStrategy scope.
57
58  Args:
59    distribution_strategy: DistributionStrategy used to distribute training
60        and validation.
61    dist_model: The replicated models on the different devices.
62    weights: The weights of the original model.
63  """
64  assign_ops = []
65  for layer in dist_model.layers:
66    num_param = len(layer.weights)
67    layer_weights = weights[:num_param]
68    for sw, w in zip(layer.weights, layer_weights):
69      if ops.executing_eagerly_outside_functions():
70        sw.assign(w)
71      else:
72        assign_ops.append(distribution_strategy.unwrap(sw.assign(w)))
73    weights = weights[num_param:]
74
75  if not ops.executing_eagerly_outside_functions():
76    backend.get_session(assign_ops).run(assign_ops)
77
78
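# Illustrative sketch only (not called anywhere in this module). It shows the
# typical way `set_weights` is driven, mirroring `_copy_weights_to_distributed_model`
# further below. `strategy`, `original_model` and `distributed_model` are
# hypothetical, caller-supplied objects.
def _example_copy_weights_to_replicas(strategy, original_model,
                                      distributed_model):
  # Pull the weights off the user-facing model and push them into the first
  # replica of the distributed model; `set_weights` handles the assignment of
  # the mirrored variables.
  orig_model_weights = original_model.get_weights()
  first_replicated_model = strategy.unwrap(distributed_model)[0]
  set_weights(strategy, first_replicated_model, orig_model_weights)

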
79def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs,
80                  grouped_updates=None, grouped_session_args=None,
81                  with_loss_tensor=False):
82  """Unwrap the list of values contained in the PerReplica parameters.
83
84  This function calls `flatten_per_replica_values` to parse each of the input
85  parameters into a list of values on the different devices. If we set
86  `with_loss_tensor` to be True, we also call `reduce` on the list of losses on
87  the different devices to give us one loss tensor.
88
89  Args:
90    distribution_strategy: DistributionStrategy used to distribute training and
91        validation.
92    grouped_inputs: PerReplica inputs returned from the train or test function
93        that we ran on each device.
94    grouped_outputs: PerReplica outputs returned from the train or test function
95        that we ran on each device.
96    grouped_updates: PerReplica updates returned from the train or test function
97        that we ran on each device.
98    grouped_session_args: PerReplica session args returned from the train or
99        test function that we ran on each device.
100    with_loss_tensor: Boolean that indicates if we need to add the reduced loss
101        tensor as one of the outputs.
102
103  Returns:
104    Values of each of the PerReplica parameters.
105
106  """
107  # Unwrap per device values returned from each model's train function.
108  # This will be used to construct the main train function.
109  all_inputs = flatten_per_replica_values(distribution_strategy,
110                                          grouped_inputs)
111  all_outputs = unwrap_outputs(distribution_strategy, grouped_outputs,
112                               with_loss_tensor)
113
114  if grouped_updates:
115    all_updates = flatten_per_replica_values(distribution_strategy,
116                                             grouped_updates)
117  else:
118    all_updates = None
119
120  all_session_args = {}
121  if grouped_session_args:
122    grouped_feed_dict = grouped_session_args.get('feed_dict')
123    if grouped_feed_dict:
124      all_session_args['feed_dict'] = flatten_per_replica_values(
125          distribution_strategy, grouped_feed_dict)
126
127    grouped_fetches = grouped_session_args.get('fetches')
128    if grouped_fetches:
129      all_session_args['fetches'] = flatten_per_replica_values(
130          distribution_strategy, grouped_fetches)
131
132  # TODO(priyag): Return only non empty/None values
133  return all_inputs, all_outputs, all_updates, all_session_args
134
135
136def unwrap_output_dict(strategy, grouped_outputs, mode):
137  """Unwrap the list of outputs contained in the PerReplica parameters."""
138  if mode == ModeKeys.PREDICT:
139    return flatten_per_replica_values(strategy, grouped_outputs)
140
141  # In the case of fit/eval, the grouped_outputs is a dict, whereas in predict
142  # the output has the same structure as the model output, so the two cases
143  # need to be treated differently.
144  total_loss = strategy.reduce(reduce_util.ReduceOp.SUM,
145                               grouped_outputs['total_loss'][0], axis=None)
146  output_losses = flatten_per_replica_values(strategy,
147                                             grouped_outputs['output_losses'])
148  metrics = flatten_per_replica_values(strategy,
149                                       grouped_outputs['metrics'])
150  batch_size = strategy.reduce(reduce_util.ReduceOp.SUM,
151                               grouped_outputs['batch_size'], axis=None)
152  if (backend.is_tpu_strategy(strategy) and
153      ops.executing_eagerly_outside_functions()):
154    # Choose 1 value per replica in the TPU case since all replicas produce the
155    # same output.
156    # We only do this in eager mode for now since this function is used in
157    # both graph and eager mode and in the graph case we currently don't use
158    # experimental_run so would need to be removed when we converge the graph
159    # code path as well.
160    output_losses = output_losses[::strategy.num_replicas_in_sync]
161    metrics = metrics[::strategy.num_replicas_in_sync]
162  return {'total_loss': [total_loss],
163          'output_losses': output_losses,
164          'metrics': metrics,
165          'batch_size': batch_size}
166
167
168def unwrap_outputs(distribution_strategy, grouped_outputs,
169                   with_loss_tensor=False):
170  """Unwrap the list of outputs contained in the PerReplica parameters.
171
172  This function calls `flatten_per_replica_values` to parse each of the input
173  parameters into a list of outputs on the different devices. If we set
174  `with_loss_tensor` to be True, we also call `reduce` on the list of losses on
175  the different devices to give us one loss tensor.
176
177  Args:
178    distribution_strategy: DistributionStrategy used to distribute training and
179        validation.
180    grouped_outputs: PerReplica outputs returned from the train or test function
181        that we ran on each device.
182    with_loss_tensor: Boolean that indicates if we need to add the reduced loss
183        tensor as one of the outputs.
184
185  Returns:
186    Values of each of the PerReplica outputs.
187
188  """
189  if not with_loss_tensor:
190    return flatten_per_replica_values(distribution_strategy,
191                                      grouped_outputs)
192
193  if not isinstance(grouped_outputs, list):
194    grouped_outputs = [grouped_outputs]
195  # reduce loss tensor before adding it to the list of fetches
196  loss = distribution_strategy.reduce(reduce_util.ReduceOp.SUM,
197                                      grouped_outputs[0], axis=None)
198  all_outputs = flatten_per_replica_values(distribution_strategy,
199                                           grouped_outputs[1:])
200  if (backend.is_tpu_strategy(distribution_strategy) and
201      ops.executing_eagerly_outside_functions()):
202    # Choose 1 value per replica in the TPU case since all replicas produce the
203    # same output.
204    # We only do this in eager mode for now since this function is used in
205    # both graph and eager mode and in the graph case we currently don't use
206    # experimental_run so would need to be removed when we converge the graph
207    # code path as well.
208    all_outputs = all_outputs[::distribution_strategy.num_replicas_in_sync]
209  return [loss] + all_outputs
210
211
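# Illustrative sketch only (not used by this module). A per-replica step
# function is assumed to return `[loss, out_1, out_2, ...]`; running it under
# the strategy yields PerReplica values, which `unwrap_outputs` turns into a
# flat list with the loss summed across replicas. `strategy`, `per_replica_fn`
# and `replica_args` are hypothetical, caller-supplied.
def _example_run_and_unwrap(strategy, per_replica_fn, replica_args):
  # `replica_args` is a tuple of (possibly PerReplica) inputs for the step.
  grouped_outputs = strategy.run(per_replica_fn, args=replica_args)
  # The first element is treated as the loss and reduced with a SUM.
  return unwrap_outputs(strategy, grouped_outputs, with_loss_tensor=True)

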
212def flatten_per_replica_values(distribution_strategy, per_replica_values):
213  """Unwraps and flattens a nest of PerReplica parameters.
214
215  PerReplica values have one value associated with each device. Each entry in
216  the PerReplica dict has a device `key` and the corresponding value on the
217  device as the `value`. In this function we take a PerReplica value or a list
218  of PerReplica values and return all the values in the PerReplica dict.
219
220  Args:
221    distribution_strategy: DistributionStrategy used to distribute training and
222      validation.
223    per_replica_values: List of PerReplica objects or a single PerReplica object.
224
225  Returns:
226    List of values of all the PerReplica objects.
227
228  """
229  # pylint: disable=g-complex-comprehension
230  # This function takes a PerReplica object or a list of PerReplica objects and
231  # returns all the values associated with it.
232  return [e for flattened in nest.flatten(per_replica_values)
233          for e in distribution_strategy.unwrap(flattened)]
234
235
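# Illustrative sketch only (not used by this module). With an N-replica
# strategy, each PerReplica entry in the (possibly nested) input unwraps to N
# values, so the returned list has `len(nest.flatten(values)) * N` elements.
# `strategy` and `per_replica_values` are hypothetical, caller-supplied.
def _example_flatten_per_replica(strategy, per_replica_values):
  flat_values = flatten_per_replica_values(strategy, per_replica_values)
  # For a single PerReplica input this is simply one value per replica.
  return flat_values

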
236def validate_callbacks(input_callbacks, optimizer):
237  """Validate whether given callbacks are supported by DistributionStrategy.
238
239  Args:
240    input_callbacks: List of callbacks passed by the user to fit.
241    optimizer: Optimizer instance used to train the model.
242
243  Raises:
244    ValueError: If `LearningRateScheduler` or `ReduceLROnPlateau` is one of the
245        callbacks passed.
246    ValueError: If `write_grads` is one of the parameters passed as part of the
247        TensorBoard callback.
248  """
249  if input_callbacks:
250    for callback in input_callbacks:
251      if isinstance(callback, (callbacks.LearningRateScheduler,
252                               callbacks.ReduceLROnPlateau)):
253
254        if not isinstance(optimizer, optimizer_v2.OptimizerV2):
255          raise ValueError('You must specify a Keras Optimizer V2 when using '
256                           '%s callback with DistributionStrategy.' % callback)
257
258      # If users want to use the TensorBoard callback they cannot use certain
259      # features of the callback that involve accessing model attributes and
260      # running ops.
261      if isinstance(callback, callbacks.TensorBoard):
262        if getattr(callback, 'write_grads', False):
263          logging.warning(
264              UserWarning(
265                  '`write_grads` in the TensorBoard callback is not supported '
266                  'when using DistributionStrategy. Setting `write_grads` '
267                  'to `False`.'))
268          callback.write_grads = False
269
270
271def validate_distributed_dataset_inputs(distribution_strategy, x, y,
272                                        sample_weights=None):
273  """Validate all the components of a DistributedValue Dataset input.
274
275  Args:
276    distribution_strategy: The current DistributionStrategy used to call
277        `fit`/`evaluate`.
278    x: Input Dataset DistributedValue object. For example, when we use
279        `MirroredStrategy` this is a PerReplica object with a tensor for each
280        device set in the dict. x can also be a tuple or dict. The keys of the
281        dict should match the names of the input layers of the model.
282    y: Target Dataset DistributedValue object. For example, when we use
283        `MirroredStrategy` this is a PerReplica object with a tensor for each
284        device set in the dict. y can also be a tuple or dict. The keys of the
285        dict should match the names of the output layers of the model.
286    sample_weights: Sample weights Dataset DistributedValue object. For example,
287        when we use `MirroredStrategy` this is a PerReplica object with a tensor
288        for each device set in the dict.
289
290  Returns:
291    The unwrapped values list of the x and y DistributedValues inputs.
292
293  Raises:
294    ValueError: If x and y do not have support for being evaluated as tensors,
295        if x and y contain elements that are not tensors, or if x and y
296        contain elements that have a shape or dtype mismatch.
297  """
298  # If the input and target used to call the model are not dataset tensors,
299  # we need to raise an error. When using a DistributionStrategy, the input
300  # and targets to a model should be from a `tf.data.Dataset`.
301
302  # If each element of x and y are not tensors, we cannot standardize and
303  # validate the input and targets.
304  x_values_list = validate_per_replica_inputs(distribution_strategy, x)
305
306  if y is not None:
307    y_values_list = validate_per_replica_inputs(distribution_strategy, y)
308  else:
309    y_values_list = None
310
311  if sample_weights is not None:
312    sample_weights_list = validate_per_replica_inputs(distribution_strategy,
313                                                      sample_weights)
314  else:
315    sample_weights_list = None
316
317  # Return the unwrapped values to avoid calling `unwrap` a second time.
318  return x_values_list, y_values_list, sample_weights_list
319
320
321def validate_per_replica_inputs(distribution_strategy, x):
322  """Validates PerReplica dataset input list.
323
324  Args:
325    distribution_strategy: The current DistributionStrategy used to call
326      `fit`, `evaluate` and `predict`.
327    x: A list of PerReplica objects that represent the input or
328      target values.
329
330  Returns:
331    List containing the first element of each of the PerReplica objects in
332    the input list.
333
334  Raises:
335    ValueError: If any of the objects in the `per_replica_list` is not a tensor.
336
337  """
338  # Convert the inputs and targets into a list of PerReplica objects.
339  per_replica_list = nest.flatten(x, expand_composites=True)
340  x_values_list = []
341  for x in per_replica_list:
342    # At this point x should contain only tensors.
343    x_values = distribution_strategy.unwrap(x)
344    for value in x_values:
345      if not tensor_util.is_tf_type(value):
346        raise ValueError('Dataset input to the model should be tensors, but '
347                         'they are of type {}.'.format(type(value)))
348
349    if not context.executing_eagerly():
350      # Validate that the shape and dtype of all the elements in x are the same.
351      validate_all_tensor_shapes(x, x_values)
352    validate_all_tensor_types(x, x_values)
353
354    x_values_list.append(x_values[0])
355  return x_values_list
356
357
358def validate_all_tensor_types(x, x_values):
359  x_dtype = x_values[0].dtype
360  for i in range(1, len(x_values)):
361    if x_dtype != x_values[i].dtype:
362      raise ValueError('Input tensor dtypes do not match for distributed tensor'
363                       ' inputs {}'.format(x))
364
365
366def validate_all_tensor_shapes(x, x_values):
367  # Validate that all the elements in x have the same shape.
368  x_shape = x_values[0].shape.as_list()
369  for i in range(1, len(x_values)):
370    if x_shape != x_values[i].shape.as_list():
371      raise ValueError('Input tensor shapes do not match for distributed tensor'
372                       ' inputs {}'.format(x))
373
374
375def _wait_for_variable_initialization(session):
376  """Utility to wait for variables to be initialized."""
377  all_variables = backend._get_variables(backend.get_graph())  # pylint: disable=protected-access
378  candidate_vars = []
379  for v in all_variables:
380    if not getattr(v, '_keras_initialized', False):
381      candidate_vars.append(v)
382
383  if not candidate_vars:
384    return
385
386  while True:
387    is_initialized = session.run(
388        [variables.is_variable_initialized(v) for v in candidate_vars])
389    uninitialized_vars = []
390    for flag, v in zip(is_initialized, candidate_vars):
391      if not flag:
392        uninitialized_vars.append(v)
393      v._keras_initialized = True  # pylint: disable=protected-access
394    if not uninitialized_vars:
395      break
396
397
398def init_restore_or_wait_for_variables():
399  """Initialize or restore variables or wait for variables to be initialized."""
400  backend._initialize_variables(backend._get_session())  # pylint: disable=protected-access
401
402
403def validate_inputs(x, y):
404  """Validate inputs when using DistributionStrategy.
405
406  Args:
407    x: Model Inputs.
408    y: Model Targets.
409
410  Raises:
411    ValueError: if the input is not a Dataset or a numpy array (when we use
412      MirroredStrategy).
413  """
414  if (isinstance(x, iterator_ops.Iterator) or
415      isinstance(y, iterator_ops.Iterator)):
416    raise ValueError('`DistributionStrategy` does not support inputs of type '
417                     'Iterator. You must pass a `tf.data.Dataset` object or a '
418                     'numpy array as input.')
419
420
421def is_dataset_shape_fully_defined(dataset):
422  """Returns whether a dataset contains a final partial batch."""
423  shapes = nest.flatten(dataset_ops.get_legacy_output_shapes(dataset))
424  unknown_shapes = [s for s in shapes if not s.is_fully_defined()]
425  return not unknown_shapes
426
427
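# Illustrative sketch only (not used by this module). A dataset batched with
# `drop_remainder=True` has fully defined static shapes; without it, the batch
# dimension is unknown and the final batch may be partial. `dataset` is a
# hypothetical, caller-supplied `tf.data.Dataset`.
def _example_check_dataset_shapes(dataset):
  fully_defined = is_dataset_shape_fully_defined(dataset)
  if not fully_defined:
    logging.warning('Dataset output shapes are not fully defined; the final '
                    'batch may be a partial batch.')
  return fully_defined

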
428def process_batch_and_step_size(strategy,
429                                inputs,
430                                batch_size,
431                                steps_per_epoch,
432                                mode,
433                                validation_split=0.):
434  """Process the batch size and step size based on input and dist strategy."""
435  first_x_value = nest.flatten(inputs)[0]
436  if isinstance(first_x_value, np.ndarray):
437    num_samples = first_x_value.shape[0]
438    if validation_split and 0. < validation_split < 1.:
439      num_samples = int(num_samples * (1 - validation_split))
440    # Until support for partial batches is implemented across all
441    # functions and distribution strategies, we pass `mode` to selectively
442    # relax the constraint that all the training samples be consumed.
443    steps_per_epoch, batch_size = get_input_params(
444        strategy, num_samples, steps_per_epoch, batch_size, mode=mode)
445  return batch_size, steps_per_epoch
446
447
448def get_input_params(distribution_strategy,
449                     num_samples,
450                     steps,
451                     batch_size,
452                     mode=None):
453  """Calculate the number of batches and steps/steps_per_epoch.
454
455  Args:
456    distribution_strategy: The DistributionStrategy used to compile the model.
457    num_samples: The number of samples from which we determine the batch size
458      and steps.
459    steps:  The specified number of steps.
460    batch_size: The specified batch_size.
461    mode: ModeKey representing whether input will be used for training,
462      evaluation, or prediction. This is used to relax the constraint on
463      consuming all the training samples, for compatibility until partial
464      batches are supported. If None, partial batches are not allowed.
465
466  Returns:
467    steps: The steps or steps_per_epoch argument, depending on whether the user
468        is calling `fit`, `evaluate` or `predict`. If the is_training flag is
469        set, we do not require all of the samples to be consumed.
470    batch_size: The batch size to be used in model iterations.
471
472  Raises:
473    ValueError: If the number of batches or steps evaluates to 0.
474
475  """
476  # TODO(b/118776054): Use global batch size for Keras/DS support.
477  # Currently this is only supported in TPUStrategy and CoreMirroredStrategy.
478  use_per_replica_batch = not dist_utils.global_batch_size_supported(
479      distribution_strategy)
480
481  # TODO(b/128995245): In eager mode, uneven batch sizes are allowed except for
482  # `fit()` on TPUStrategy.
483  # In graph mode, the zero batch case in batch norm is not handled due to
484  # XLA-GPU regression. Uneven batch sizes are not allowed except
485  # for `test()` and `predict()` on TPUStrategy.
486  if context.executing_eagerly():
487    allow_partial_batch = (
488        mode != ModeKeys.TRAIN or
489        not backend.is_tpu_strategy(distribution_strategy))
490  else:
491    allow_partial_batch = (
492        mode == ModeKeys.TRAIN or
493        ((mode == ModeKeys.PREDICT or mode == ModeKeys.TEST) and
494         backend.is_tpu_strategy(distribution_strategy)))
495
496  if steps is None:
497    if batch_size is None:
498      # If neither the batch size nor the number of steps is set, we choose the
499      # global batch size as the minimum of the number of samples and 32. 32 is
500      # chosen to provide backward compatibility.
501      global_batch_size = min(num_samples, 32)
502    else:
503      # If the user provided the batch size, we need to account for whether the
504      # strategy uses a global or a per-replica batch size.
505      global_batch_size = batch_size
506      if use_per_replica_batch:
507        global_batch_size *= distribution_strategy.num_replicas_in_sync
508    if allow_partial_batch:
509      steps = np.ceil(num_samples / global_batch_size).astype(int)
510    else:
511      if num_samples % global_batch_size:
512        raise ValueError('The number of samples %s is not divisible by '
513                         'batch size %s.' % (num_samples, global_batch_size))
514      steps = num_samples // global_batch_size
515  else:
516    if batch_size is None:
517      # We calculate the batch size based on the number of steps specified
518      if num_samples % steps:
519        raise ValueError('The number of samples %s is not divisible by '
520                         'steps %s. Please change the number of steps to a '
521                         'value that can consume all the samples.' % (
522                             num_samples, steps))
523      global_batch_size = num_samples // steps
524    else:
525      # If the user provided the batch size, we need to account for whether the
526      # strategy uses a global or a per-replica batch size.
527      global_batch_size = batch_size
528      if use_per_replica_batch:
529        global_batch_size *= distribution_strategy.num_replicas_in_sync
530
531      min_num_samples = global_batch_size * steps
532      if allow_partial_batch:
533        min_num_samples = global_batch_size * (steps-1) + 1 if steps > 1 else 0
534
535      if num_samples < min_num_samples:
536        raise ValueError('Number of samples %s is less than samples required '
537                         'for specified batch_size %s and steps %s' % (
538                             num_samples, global_batch_size, steps))
539
540  # We need to return the per replica or global batch size based on the strategy
541  if use_per_replica_batch:
542    if global_batch_size % distribution_strategy.num_replicas_in_sync:
543      raise ValueError(
544          'The batch size (%s) could not be sharded evenly across the sync '
545          'replicas (%s) in the distribution strategy.' % (
546              global_batch_size, distribution_strategy.num_replicas_in_sync))
547    batch_size = global_batch_size // distribution_strategy.num_replicas_in_sync
548  else:
549    batch_size = global_batch_size
550
551  return steps, batch_size
552
553
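# Illustrative worked example (not used by this module). Assume a 4-replica
# strategy for which `global_batch_size_supported` is False, so the
# per-replica `batch_size=32` is scaled to a global batch of 32 * 4 = 128.
# With `num_samples=1024` that resolves to 1024 // 128 = 8 steps, and the
# returned per-replica batch size stays 32. The `strategy` argument is
# hypothetical, caller-supplied.
def _example_resolve_steps_and_batch_size(strategy):
  steps, batch_size = get_input_params(
      strategy, num_samples=1024, steps=None, batch_size=32, mode=None)
  return steps, batch_size

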
554def get_batch_dimension(iterator):
555  shapes = nest.flatten(dataset_ops.get_legacy_output_shapes(iterator))
556  # Take the batch size from the first element, as it should be the same for
557  # all.
558  dims = shapes[0].dims
559  return dims[0] if dims else None
560
561
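# Illustrative sketch only (not used by this module). The static batch
# dimension read off an iterator is a `Dimension` whose `.value` is an int
# when the dataset was batched with `drop_remainder=True` and None otherwise.
# `iterator` is a hypothetical, caller-supplied dataset iterator.
def _example_static_batch_size(iterator):
  batch_dim = get_batch_dimension(iterator)
  return batch_dim.value if batch_dim is not None else None

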
562def get_iterator(dataset, distribution_strategy):
563  with distribution_strategy.scope():
564    iterator = distribution_strategy.make_dataset_iterator(dataset)
565  initialize_iterator(iterator, distribution_strategy)
566  return iterator
567
568
569def initialize_iterator(iterator, distribution_strategy):
570  with distribution_strategy.scope():
571    init_op = control_flow_ops.group(iterator.initializer)
572    if not context.executing_eagerly():
573      backend.get_session((init_op,)).run(init_op)
574
575
576def _get_input_from_iterator(iterator, model):
577  """Get elements from the iterator and verify the input shape and type."""
578  next_element = iterator.get_next()
579
580  # `len(nest.flatten(x))` does not count empty elements such as {}, e.g.
581  # len(nest.flatten([[0,1,2], {}])) is 3 and not 4. The `next_element` is
582  # flattened in `_prepare_feed_values` to work around that; empty elements
583  # are filtered out as part of the flattening.
584  if len(nest.flatten(next_element)) == len(model.inputs):
585    x = next_element
586    y = None
587    sample_weights = None
588  elif len(nest.flatten(next_element)) == (len(model.inputs) +
589                                           len(model.outputs)):
590    x, y = next_element
591    sample_weights = None
592  else:
593    x, y, sample_weights = next_element
594
595  # Validate that all the elements in x and y are of the same type and shape.
596  validate_distributed_dataset_inputs(
597      model._distribution_strategy, x, y, sample_weights)
598  return x, y, sample_weights
599
600
601def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
602  """Prepare feed values to the model execution function.
603
604  Args:
605    model: Model to prepare feed values for.
606    inputs: List or dict of model inputs.
607    targets: Optional list of model targets.
608    sample_weights: Optional list of sample weight arrays.
609    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
610
611  Returns:
612    Feed values for the model in the given mode.
613  """
614  strategy = model._distribution_strategy
615  inputs, targets, sample_weights = _get_input_from_iterator(inputs, model)
616  if backend.is_tpu_strategy(strategy):
617    if sample_weights is not None:
618      raise ValueError('TPUStrategy does not support sample weights.')
619
620  # When the inputs are dict, then we want to flatten it in the same order as
621  # the input layers, such that the data are fed into the input layers in the
622  # correct order.
623  if isinstance(inputs, dict):
624    inputs = [inputs[key] for key in model._feed_input_names]
625  if is_distributing_by_cloning(model):
626    inputs = flatten_per_replica_values(strategy, inputs)
627    targets = flatten_per_replica_values(strategy, targets)
628    # Expand 1-dimensional inputs.
629    # TODO(b/124535720): Remove once this data standardization logic is shared
630    # with the main flow.
631    inputs, targets = nest.map_structure(
632        training_utils_v1.standardize_single_array, (inputs, targets))
633  else:
634    inputs = training_utils_v1.ModelInputs(inputs).as_list()
635
636  if mode == ModeKeys.PREDICT:
637    sample_weights = []
638    targets = []
639  elif sample_weights is not None and is_distributing_by_cloning(model):
640    if context.executing_eagerly() and not model._compile_distribution:
641      raise NotImplementedError('`sample_weight` is not supported when using '
642                                'tf.distribute.Strategy in eager mode and '
643                                'cloning=True.')
644    sample_weights = flatten_per_replica_values(strategy, sample_weights)
645
646  ins = [inputs, targets, sample_weights]
647  return tuple(ins)
648
649
650def is_distributing_by_cloning(model):
651  """Decide whether this model is going to be distributed via cloning.
652
653  We are going to distribute the model by cloning in graph mode.
654
655  Args:
656    model: Keras model to distribute.
657
658  Returns:
659    True if the `model` is going to be distributed using cloning and False
660    otherwise.
661  """
662  if (backend.is_tpu_strategy(model._distribution_strategy) and
663      context.executing_eagerly):  # b/137580852
664    return False
665  elif ops.executing_eagerly_outside_functions():
666    return bool(model._compile_distribution)
667  return True
668
669
670def _custom_compile_for_predict(model):
671  """Custom compile for TPU predict mode."""
672  if not model.built:
673    # Model is not compilable because it does not know its number of inputs
674    # and outputs, nor their shapes and names. We will compile after the first
675    # time the model gets called on training data.
676    return
677  model._is_compiled = True
678  model.total_loss = None
679  model.train_function = None
680  model.test_function = None
681  model.predict_function = None
682
683
684def _build_network_on_replica(model, mode, inputs=None, targets=None):
685  """Build an updated model on replicas.
686
687  We create a new Keras model while sharing the variables from the old graph.
688  Building a new sub-graph is required since the original keras model creates
689  placeholders for the input and the output that are not accessible till we
690  call iterator.get_next() inside the step_fn for `fit`/`evaluate`/`predict`.
691
692  The sharing of weights and layers between the old and the new model guarantees
693  that we're using Strategy variables and that any updates on either model are
694  reflected correctly in callbacks and loop iterations.
695
696  We need to make sure we share the optimizers between the old and the new model
697  as well so that optimizer state is not lost if the user is running fit
698  multiple times.
699
700  Args:
701    model: Model to be replicated across Replicas
702    mode: Which of fit/eval/predict is building the distributed network
703    inputs: Input variables to be passed to the model
704    targets: Target tensor to be passed to model.compile
705
706  Returns:
707    A new model with shared layers with the old model.
708  """
709  # Need to do imports here since we run into a circular dependency error.
710  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
711  from tensorflow.python.keras.engine import sequential  # pylint: disable=g-import-not-at-top
712
713  # We rely on the internal methods to avoid exposing `share_weights` in the
714  # public API.
715  if isinstance(model, sequential.Sequential):
716    updated_model = models._clone_sequential_model(
717        model, input_tensors=inputs, layer_fn=models.share_weights)
718  else:
719    updated_model = models._clone_functional_model(
720        model, input_tensors=inputs, layer_fn=models.share_weights)
721    # Callable losses added directly to a functional Model need to be added
722    # here.
723    updated_model._callable_losses = model._callable_losses
724
725  # Recast all low-precision outputs back to float32 since we only cast
726  # the inputs to bfloat16, not the targets. This is done so that we can
727  # preserve precision when calculating the loss value.
728  def _upcast_low_precision_outputs(output):
729    if output.dtype == dtypes.bfloat16:
730      return math_ops.cast(output, dtypes.float32)
731    else:
732      return output
733  updated_model.outputs = [_upcast_low_precision_outputs(o)
734                           for o in updated_model.outputs]
735
736  if isinstance(targets, tuple):
737    targets = nest.flatten(targets)
738
739  if mode == ModeKeys.PREDICT and inputs is not None:  # TPU predict case
740    _custom_compile_for_predict(updated_model)
741  else:
742    updated_model.compile(
743        model.optimizer,
744        model.loss,
745        metrics=metrics_module.clone_metrics(model._compile_metrics),
746        loss_weights=model.loss_weights,
747        sample_weight_mode=model.sample_weight_mode,
748        weighted_metrics=metrics_module.clone_metrics(
749            model._compile_weighted_metrics),
750        target_tensors=targets)
751  return updated_model
752
753
754def _build_distributed_network(model, strategy, mode, inputs=None,
755                               targets=None):
756  """Create a cloned model on each replica."""
757  with backend.get_graph().as_default(), strategy.scope():
758    distributed_model = strategy.extended.call_for_each_replica(
759        _build_network_on_replica,
760        args=(model, mode, inputs, targets))
761    set_distributed_model(model, mode, distributed_model)
762
763
764def _clone_and_build_model(model, mode, inputs=None, targets=None):
765  """Clone and build the given keras_model."""
766  # We need to set the import here since we run into a circular dependency
767  # error.
768  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
769  cloned_model = models.clone_model(model, input_tensors=inputs)
770
771  # Compile and build model.
772  if isinstance(model.optimizer, optimizers.TFOptimizer):
773    optimizer = model.optimizer
774  else:
775    optimizer_config = model.optimizer.get_config()
776    optimizer = model.optimizer.__class__.from_config(optimizer_config)
777
778  # Recast all low-precision outputs back to float32 since we only cast
779  # the inputs to bfloat16, not the targets. This is done so that we can
780  # preserve precision when calculating the loss value.
781  def _upcast_low_precision_outputs(output):
782    if output.dtype == dtypes.bfloat16:
783      return math_ops.cast(output, dtypes.float32)
784    else:
785      return output
786  cloned_model.outputs = [_upcast_low_precision_outputs(o)
787                          for o in cloned_model.outputs]
788
789  if isinstance(targets, tuple):
790    targets = nest.flatten(targets)
791  if mode == ModeKeys.PREDICT and inputs is not None:  # TPU predict case
792    _custom_compile_for_predict(cloned_model)
793  else:
794    cloned_model.compile(
795        optimizer,
796        model.loss,
797        metrics=metrics_module.clone_metrics(model._compile_metrics),
798        loss_weights=model.loss_weights,
799        sample_weight_mode=model.sample_weight_mode,
800        weighted_metrics=metrics_module.clone_metrics(
801            model._compile_weighted_metrics),
802        target_tensors=targets)
803  return cloned_model
804
805
806def clone_model_on_replicas(model, strategy, mode, inputs=None, targets=None):
807  """Create a cloned model on each replica."""
808  with backend.get_graph().as_default(), strategy.scope():
809    distributed_model = strategy.extended.call_for_each_replica(
810        _clone_and_build_model, args=(model, mode, inputs, targets))
811    set_distributed_model(model, mode, distributed_model)
812  if mode == ModeKeys.TRAIN:
813    model._make_callback_model(distributed_model)
814
815
816def _make_execution_function(model, mode):
817  """Makes or reuses function to run one step of distributed model execution."""
818  if is_distributing_by_cloning(model):
819    return _make_execution_function_with_cloning(model, mode)
820
821  distributed_function = get_distributed_function(model, mode)
822  if distributed_function:
823    return distributed_function
824
825  distribution_function = _make_execution_function_without_cloning(model, mode)
826  set_distributed_function(model, mode, distribution_function)
827  return distribution_function
828
829
830def _make_execution_function_without_cloning(model, mode):
831  """Creates a function to run one step of distributed model execution."""
832  strategy = model._distribution_strategy
833
834  with strategy.scope():
835    per_replica_function = _make_replica_execution_function(model, mode)
836
837    def distributed_function(input_fn):
838      """A single step of the distributed execution across replicas."""
839      x, y, sample_weights = input_fn()
840      # Call `Model.{train,test,predict}_on_batch` on every replica passing
841      # PerReplicas as arguments.  On every replica inside this call, each
842      # PerReplica object will return the value for that replica.  The outputs
843      # are PerReplicas too.
844      outputs = strategy.run(per_replica_function, args=(x, y, sample_weights))
845      # Out of PerReplica outputs reduce or pick values to return.
846      all_outputs = unwrap_outputs(
847          strategy, outputs, with_loss_tensor=(mode != ModeKeys.PREDICT))
848      return all_outputs
849
850    if not model.run_eagerly:
851      distributed_function = def_function.function(distributed_function)
852      def execution_function(input_fn):
853        # `numpy` translates Tensors to values in Eager mode.
854        return [out.numpy() for out in distributed_function(input_fn)]
855    else:
856      execution_function = distributed_function
857
858    return execution_function
859
860
861def _make_replica_execution_function(model, mode):
862  """A single step of the distributed execution on a replica."""
863  if mode == ModeKeys.TRAIN:
864    func = model.train_on_batch
865  elif mode == ModeKeys.TEST:
866    func = model.test_on_batch
867  else:
868
869    def predict_on_batch(x, y=None, sample_weights=None):
870      del y, sample_weights
871      return model.predict_on_batch(x)
872
873    func = predict_on_batch
874
875  if mode != ModeKeys.PREDICT:
876    # `reset_metrics` is set to False to maintain stateful metrics across
877    # batch-level calls.
878    func = functools.partial(func, reset_metrics=False)
879
880  return func
881
882
883def _make_replicated_models_with_cloning(model, mode):
884  """Build models on each replica."""
885  strategy = model._distribution_strategy
886
887  # If distributed_model is not built, create one for `mode`.
888  if model._compile_distribution:
889    clone_model_on_replicas(model, strategy, mode)
890  else:
891    _build_distributed_network(model, strategy, mode)
892
893
894def _make_execution_function_with_cloning(model, mode):
895  """Clones or re-uses models to run one step of distributed model execution."""
896  distributed_model = get_distributed_model(model, mode)
897  # TODO(b/134069401): Create a cache for the distributed model and exec
898  # function that incorporates additional attributes, beyond just the mode, as
899  # part of the cache key.
900  # If the distributed model for a particular `mode` is already built, reuse the
901  # `_distributed_function` cached on that distributed model.
902  # If you have updated the sample_weight_mode on the model, then you will need
903  # to recompile metrics and recreate the execution function. This is indicated
904  # by the `_recompile_exec_function` property.
905  if (distributed_model and hasattr(distributed_model, '_distributed_function')
906      and not (hasattr(distributed_model, '_recompile_exec_function') and
907               distributed_model._recompile_exec_function)):
908    return distributed_model._distributed_function
909
910  if not distributed_model:
911    _make_replicated_models_with_cloning(model, mode)
912    distributed_model = get_distributed_model(model, mode)
913  assert distributed_model
914
915  # Also create an execution function on that distributed model.
916  if context.executing_eagerly():
917    distributed_function = _make_eager_execution_function(model, mode)
918  else:
919    distributed_function = _make_graph_execution_function(model, mode)
920
921  # We cache the distributed execution function on the model since creating
922  # distributed models and execution functions are expensive.
923  distributed_model._distributed_function = distributed_function
924  distributed_model._recompile_exec_function = False
925  return distributed_function
926
927
928def _make_graph_execution_function(model, mode):
929  """Makes function to run one step of distributed model in graph mode."""
930
931  def _per_replica_function(model):
932    f = model._make_execution_function(mode)
933    return (f.inputs, f.outputs, f.updates_op, f.session_kwargs)
934
935  strategy = model._distribution_strategy
936  with strategy.scope():
937    # Create train ops on each of the devices when we call
938    # `_per_replica_function`.
939    (grouped_inputs, grouped_outputs, grouped_updates,
940     grouped_session_args) = strategy.extended.call_for_each_replica(
941         _per_replica_function, args=(get_distributed_model(model, mode),))
942
943    # Initialize the variables in the replicated model. This is necessary for
944    # multi-worker training because on some workers, initialization is not
945    # needed. This method does initialization or waiting for initialization
946    # according to the context object of distribute coordinator.
947    init_restore_or_wait_for_variables()
948
949    # Unwrap all the per device values returned from `call_for_each_replica`.
950    # Unwrapping per device values gives you a list of values that can be
951    # used to construct a new train function that is composed of update ops on
952    # all the devices over which the model is distributed.
953    (all_inputs, all_outputs, all_updates, all_session_args) = unwrap_values(
954        strategy,
955        grouped_inputs,
956        grouped_outputs,
957        grouped_updates,
958        grouped_session_args,
959        with_loss_tensor=(mode != ModeKeys.PREDICT))
960
961    return backend.function(
962        all_inputs,
963        all_outputs,
964        updates=all_updates,
965        name='distributed_{}_function'.format(mode),
966        **all_session_args)
967
968
969def _make_eager_execution_function(model, mode):
970  """Makes function to run one step of distributed model eager execution."""
971  def _per_replica_function(model):
972    f = model._make_execution_function(mode)
973    return (f.inputs, f.outputs)
974
975  # NOTE(priyag): Try creating a new FuncGraph within DS scope instead of using
976  # the global one.
977  strategy = model._distribution_strategy
978  global_graph = backend.get_graph()
979
980  with global_graph.as_default(), strategy.scope():
981    # First we gather the relevant portions of the model across all replicas.
982    # `backend._scratch_graph(global_graph)` signals to Keras that it should not
983    # lift to a separate graph when creating the per-replica functions.
984    with backend._scratch_graph(global_graph):
985      # Create train ops on each of the devices when we call
986      # `_per_replica_function`.
987      grouped = strategy.extended.call_for_each_replica(
988          _per_replica_function, args=(get_distributed_model(model, mode),))
989      grouped_inputs, grouped_outputs = grouped
990
991      # Unwrap all the per device values returned from `call_for_each_replica`.
992      # Unwrapping per device values gives you a list of values that can be
993      # used to construct a new train function that is composed of
994      # inputs/outputs on all the devices over which the model is distributed.
995      (all_inputs, all_outputs, _, _) = unwrap_values(
996          strategy,
997          grouped_inputs,
998          grouped_outputs,
999          with_loss_tensor=(mode != ModeKeys.PREDICT))
1000
1001    # Finally, a joint Keras function is created; this one will be created in
1002    # a separate FuncGraph.
1003    return backend.function(
1004        all_inputs,
1005        all_outputs,
1006        name='eager_distributed_{}_function'.format(mode))
1007
1008
1009def _copy_weights_to_distributed_model(original_model, mode):
1010  """Copies weights from original model to distributed models."""
1011  strategy = original_model._distribution_strategy
1012  distributed_model = get_distributed_model(original_model, mode)
1013  if strategy:
1014    # Copy the weights from the original model to each of the replicated
1015    # models.
1016    orig_model_weights = original_model.get_weights()
1017    first_model = strategy.unwrap(distributed_model)[0]
1018    set_weights(strategy, first_model, orig_model_weights)
1019
1020
1021def _copy_weights_to_original_model(model, mode):
1022  """Copies weights from first distributed model back to original model."""
1023  if model._distribution_strategy and mode == ModeKeys.TRAIN:
1024    distributed_model = get_distributed_model(model, mode)
1025    updated_weights = model._distribution_strategy.unwrap(
1026        distributed_model)[0].get_weights()
1027    model.set_weights(updated_weights)
1028
1029
1030def _per_replica_aggregate_batch(strategy, batch_outs, model, mode):
1031  """Aggregates the per-replica batch-level outputs from a distributed step."""
1032  if strategy is not None and mode == ModeKeys.PREDICT:
1033    total_batch_outs = []
1034    for i in range(len(model.outputs)):
1035      num_replicas = strategy.num_replicas_in_sync
1036      nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas]
1037      total_batch_outs.append(
1038          concat_along_batch_dimension(nest.flatten(nested_outs)))
1039    return total_batch_outs
1040  return batch_outs
1041
1042
1043def _reset_metrics(model):
1044  if model._distribution_strategy:
1045    for mode in [ModeKeys.TRAIN, ModeKeys.TEST, ModeKeys.PREDICT]:
1046      distributed_model = get_distributed_model(model, mode)
1047      if distributed_model:
1048        first_model = model._distribution_strategy.unwrap(distributed_model)[0]
1049        first_model.reset_metrics()
1050
1051
1052def get_distributed_model(model, mode):
1053  key = _generate_cache_key(mode)
1054  return model._distributed_model_cache.get(key, None)
1055
1056
1057def set_distributed_model(model, mode, distributed_model):
1058  key = _generate_cache_key(mode)
1059  model._distributed_model_cache[key] = distributed_model
1060
1061
1062def get_distributed_function(model, mode):
1063  key = _generate_cache_key(mode)
1064  return model._distributed_function_cache.get(key, None)
1065
1066
1067def set_distributed_function(model, mode, distributed_function):
1068  key = _generate_cache_key(mode)
1069  model._distributed_function_cache[key] = distributed_function
1070
1071
1072def _generate_cache_key(mode):
1073  key = hash(mode)
1074  return key
1075
1076
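# Illustrative sketch only (not used by this module). The helpers above store
# per-mode distributed models and functions in plain dicts on the Keras model,
# keyed by `_generate_cache_key(mode)`. `model` and `distributed_model` are
# hypothetical, caller-supplied objects; `model` is assumed to carry the
# `_distributed_model_cache` dict used by these helpers.
def _example_distributed_model_cache_roundtrip(model, distributed_model):
  set_distributed_model(model, ModeKeys.TRAIN, distributed_model)
  # Returns the object stored for ModeKeys.TRAIN, or None if nothing cached.
  return get_distributed_model(model, ModeKeys.TRAIN)

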
1077@tf_contextlib.contextmanager
1078def distributed_scope(strategy, learning_phase):
1079  with strategy.scope(), backend.learning_phase_scope(learning_phase):
1080    yield
1081
1082
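# Illustrative sketch only (not used by this module). `distributed_scope`
# combines the strategy scope with a Keras learning-phase scope (1 for
# training, 0 for inference). `strategy` and `build_fn` are hypothetical,
# caller-supplied.
def _example_build_under_distributed_scope(strategy, build_fn):
  with distributed_scope(strategy, learning_phase=1):
    # Anything created here (variables, models) is placed under the strategy.
    return build_fn()

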
1083def is_current_worker_chief():
1084  return dc.get_current_worker_context().is_chief
1085
1086
1087def filter_distributed_callbacks(callbacks_list, model):
1088  """Filter Callbacks based on the worker context when running multi-worker.
1089
1090  Args:
1091    callbacks_list: A list of `Callback` instances.
1092    model: Keras model instance.
1093
1094  Returns:
1095    The list of `Callback` instances that should be run on this worker.
1096  """
1097
1098  if not model._in_multi_worker_mode():
1099    raise ValueError(
1100        'filter_distributed_callbacks() should only be called when Keras '
1101        'is in multi worker mode.')
1102
1103  callbacks_list = callbacks_list or []
1104  if not [
1105      c for c in callbacks_list if isinstance(c, callbacks.ModelCheckpoint)
1106  ]:
1107    # TODO(rchao): Consider providing a ModelCheckpoint here if the user
1108    # fails to (possibly with tempfile directory).
1109    logging.warning('ModelCheckpoint callback is not provided. '
1110                    'Workers will need to restart training if any fails.')
1111
1112  if callbacks_list is None or is_current_worker_chief():
1113    return callbacks_list
1114
1115  # Some Callbacks should only run on the chief worker.
1116  return [
1117      callback for callback in callbacks_list if not callback._chief_worker_only
1118  ]  # pylint: disable=protected-access
1119
1120
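# Illustrative sketch only (not used by this module). On a non-chief worker,
# chief-only callbacks (e.g. checkpointing) are filtered out before the list
# is handed to the training loop; the chief keeps the full list. `model` and
# `callbacks_list` are hypothetical, caller-supplied, and `model` is assumed
# to be running in multi-worker mode.
def _example_callbacks_for_this_worker(model, callbacks_list):
  return filter_distributed_callbacks(callbacks_list, model)

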
1121def _update_sample_weight_modes(model, mode, sample_weights):
1122  """Update sample_weight_mode of the distributed model."""
1123  if is_distributing_by_cloning(model):
1124    distributed_model = get_distributed_model(model, mode)
1125    if not distributed_model:
1126      _make_replicated_models_with_cloning(model, mode)
1127      distributed_model = get_distributed_model(model, mode)
1128    distributed_model._recompile_exec_function = any(
1129        [e.sample_weights_mismatch() for e in model._training_endpoints])
1130
1131    if sample_weights:
1132      distributed_models = flatten_per_replica_values(
1133          model._distribution_strategy, distributed_model)
1134      # sample_weights is a tuple containing a single list whose length equals
1135      # the number of replicas in sync.
1136      sample_weights = sample_weights[0]
1137      if sample_weights and None not in sample_weights:
1138        for m, sw in zip(distributed_models, sample_weights):
1139          m._update_sample_weight_modes(sample_weights=[sw])
1140
1141
1142def concat_along_batch_dimension(outputs):
1143  """Concats prediction outputs along the batch dimension."""
1144  if isinstance(outputs[0], sparse_tensor.SparseTensor):
1145    return sparse_ops.sparse_concat_v2(axis=0, sp_inputs=outputs)
1146  if isinstance(outputs[0], ragged_tensor.RaggedTensor):
1147    return array_ops.concat(outputs, axis=0)
1148  return np.concatenate(outputs)
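

# Illustrative sketch only (not used by this module). Dense per-replica
# prediction outputs are joined with `np.concatenate` along the batch (first)
# dimension; sparse and ragged outputs take the dedicated concat paths above.
# The arrays below are hypothetical stand-ins for two replicas' outputs.
def _example_merge_prediction_outputs():
  replica_0_out = np.zeros((2, 1), dtype=np.float32)
  replica_1_out = np.ones((3, 1), dtype=np.float32)
  merged = concat_along_batch_dimension([replica_0_out, replica_1_out])
  return merged  # A numpy array of shape (5, 1).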
1149