# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities related to distributed training."""
# pylint:disable=protected-access

from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.keras import backend
from tensorflow.python.ops import variables


# TODO(b/118776054): Currently we support global batch size for TPUStrategy and
# core MirroredStrategy only. Remove this check when contrib MirroredStrategy is
# no longer needed.
def global_batch_size_supported(distribution_strategy):
  """Returns whether `distribution_strategy` supports a global batch size.

  Args:
    distribution_strategy: The distribution strategy to query.

  Returns:
    The value of `distribution_strategy.extended._global_batch_size`.
  """
  return distribution_strategy.extended._global_batch_size  # pylint: disable=protected-access


def call_replica_local_fn(fn, *args, **kwargs):
  """Call a function that uses replica-local variables.

  This function correctly handles calling `fn` in a cross-replica
  context.

  The strategy is taken from the `strategy` keyword argument if one is given.
  Otherwise it is taken from the current context, if there is one. If a
  (truthy) non-TPU strategy is found and the call happens in a cross-replica
  context, `fn` runs via `strategy.extended.call_for_each_replica` inside
  `strategy.scope()`. In every other case `fn` is called directly.

  Args:
    fn: The function to call.
    *args: Positional arguments to the `fn`.
    **kwargs: Keyword arguments to `fn`. The special key `strategy`, if
      present, is removed from `kwargs` and used as the distribution strategy
      instead of the one from the current context. It is not forwarded to
      `fn`.

  Returns:
    The result of calling `fn`. In the cross-replica case this is whatever
    `strategy.extended.call_for_each_replica` returns.
  """
  # TODO(b/132666209): Remove this function when we support assign_*
  # for replica-local variables.
  strategy = None
  if 'strategy' in kwargs:
    # An explicitly passed strategy wins, even if it is None; it is consumed
    # here so it is not forwarded to `fn`.
    strategy = kwargs.pop('strategy')
  elif ds_context.has_strategy():
    strategy = ds_context.get_strategy()

  # TODO(b/120571621): TPUStrategy does not implement replica-local variables.
  is_tpu = backend.is_tpu_strategy(strategy)
  if (not is_tpu) and strategy and ds_context.in_cross_replica_context():
    with strategy.scope():
      return strategy.extended.call_for_each_replica(fn, args, kwargs)
  return fn(*args, **kwargs)


def is_distributed_variable(v):
  """Returns whether `v` is a distributed variable.

  A value counts as a distributed variable when it is both a
  `values_lib.DistributedValues` and a `variables.Variable`.

  Args:
    v: The value to check.

  Returns:
    True if `v` is an instance of both classes, False otherwise.
  """
  return (isinstance(v, values_lib.DistributedValues) and
          isinstance(v, variables.Variable))