xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/distribute/distributed_training_utils.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities related to distributed training."""
16# pylint:disable=protected-access
17
18from tensorflow.python.distribute import distribution_strategy_context as ds_context
19from tensorflow.python.distribute import values as values_lib
20from tensorflow.python.keras import backend
21from tensorflow.python.ops import variables
22
23
24# TODO(b/118776054): Currently we support global batch size for TPUStrategy and
25# core MirroredStrategy only. Remove this check when contrib MirroredStrategy is
26# no longer needed.
def global_batch_size_supported(distribution_strategy):
  """Returns whether `distribution_strategy` supports a global batch size.

  Args:
    distribution_strategy: A `tf.distribute.Strategy` instance.

  Returns:
    The strategy's private `_global_batch_size` flag; truthy when the strategy
    operates on global (rather than per-replica) batch sizes.
  """
  return getattr(distribution_strategy.extended, '_global_batch_size')  # pylint: disable=protected-access
29
30
def call_replica_local_fn(fn, *args, **kwargs):
  """Call a function that uses replica-local variables.

  This function correctly handles calling `fn` in a cross-replica
  context.

  Args:
    fn: The function to call.
    *args: Positional arguments to the `fn`.
    **kwargs: Keyword argument to `fn`.

  Returns:
    The result of calling `fn`.
  """
  # TODO(b/132666209): Remove this function when we support assign_*
  # for replica-local variables.
  # Resolve the strategy: an explicit `strategy` kwarg wins (even if it is
  # None); otherwise fall back to the ambient strategy, if any.
  if 'strategy' in kwargs:
    strategy = kwargs.pop('strategy')
  elif ds_context.has_strategy():
    strategy = ds_context.get_strategy()
  else:
    strategy = None

  # TODO(b/120571621): TPUStrategy does not implement replica-local variables.
  under_tpu = backend.is_tpu_strategy(strategy)
  if strategy and not under_tpu and ds_context.in_cross_replica_context():
    # Re-enter the strategy scope and fan `fn` out to every replica.
    with strategy.scope():
      return strategy.extended.call_for_each_replica(fn, args, kwargs)
  return fn(*args, **kwargs)
60
61
def is_distributed_variable(v):
  """Returns whether `v` is a distributed variable.

  A distributed variable must be both a `DistributedValues` container and a
  `tf.Variable`; either property alone is not sufficient.
  """
  if not isinstance(v, values_lib.DistributedValues):
    return False
  return isinstance(v, variables.Variable)
66