xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/values_util.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility functions used by values.py and ps_values.py."""
16
17from tensorflow.python.distribute import distribute_lib
18from tensorflow.python.distribute import distribution_strategy_context as ds_context
19from tensorflow.python.distribute import reduce_util
20from tensorflow.python.eager import context
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import tensor_util
23from tensorflow.python.ops import control_flow_ops
24from tensorflow.python.ops import math_ops
25from tensorflow.python.ops import variable_scope as vs
26from tensorflow.python.saved_model import save_context
27from tensorflow.python.saved_model import save_options
28from tensorflow.python.training.saving import saveable_object
29
30
def write_object_proto(var, proto, options):
  """Update a SavedObject proto for the caller.

  If a DistributedVariable object supports this method, it will be called when
  saving with a pre-built `SavedObject` proto representing the object, plus an
  instance of `SaveOptions`. This method is then free to modify that proto
  instance.

  `DistributedVariable` with `AUTO` or `ON_WRITE` synchronization optionally
  write out information about their components to the
  `experimental_distributed_variable_components` field of a
  `SavedVariable` (depending on the `SaveOptions` variable policy).

  Args:
    var: The DistributedVariable object.
    proto: A pre-built `SavedObject` proto for this object. It is assumed this
      will be a `SavedVariable` instance.
    options: A `SaveOptions` instance.
  """
  if options.experimental_variable_policy._expand_distributed_variables(  # pylint: disable=protected-access
  ):
    # Use a distinct loop variable: the previous code shadowed the `var`
    # parameter inside this loop, which was confusing and would silently
    # break any later use of the outer `var` in this function.
    for component in var.values:
      component_proto = (
          proto.variable.experimental_distributed_variable_components.add())
      # Strip the output suffix (e.g. ":0") so only the op name is recorded.
      component_proto.name = component.name.split(":")[0]
      component_proto.device = component.device
57
58
def get_on_write_saveable(var, primary_var, name):
  """Return saveable spec for AUTO and ON_WRITE variables."""

  def tensor():
    # Evaluated lazily: during a restore this callable is never invoked, so
    # we avoid reading the variable unnecessarily.
    if context.executing_eagerly() and not primary_var.is_initialized():
      # A SaveSpec tensor value of `None` indicates that the variable is
      # uninitialized.
      return None
    return var.distribute_strategy.extended.read_var(var)

  return tensor, [
      saveable_object.SaveSpec(
          tensor=tensor,
          slice_spec="",
          name=name,
          dtype=var.dtype,
          device=primary_var.device)
  ]
79
80
def get_on_write_restore_ops(var, tensor):
  """Return restore ops for AUTO and ON_WRITE variables."""
  packed = var._packed_variable  # pylint: disable=protected-access
  if packed is None:
    # No packed representation: assign to every component on its own device.
    assigns = tuple(
        assign_on_device(v.device, v, tensor) for v in var.values)
  else:
    # Packed variable: assign through the packed handle once per device.
    assigns = tuple(
        assign_on_device(d, packed, tensor) for d in packed.devices)
  return control_flow_ops.group(assigns)
93
94
def get_on_read_saveable(var, primary_var, name):
  """Return saveables for ON_READ variable."""

  def tensor():
    # Evaluated lazily so nothing is computed when restoring rather than
    # saving; reads the variable's value in cross-replica context.
    return var._get_cross_replica()  # pylint: disable=protected-access

  return tensor, [
      saveable_object.SaveSpec(
          tensor=tensor,
          slice_spec="",
          name=name,
          dtype=var.dtype,
          device=primary_var.device)
  ]
111
112
def get_on_read_restore_ops(var, tensor, aggregation):
  """Return restore ops for ON_READ variables."""
  if aggregation == vs.VariableAggregation.SUM:
    # A SUM-aggregated variable was saved as the total across devices, so
    # divide the restored value evenly to preserve that total on re-save.
    num_sync = var.distribute_strategy.num_replicas_in_sync
    tensor = math_ops.cast(tensor / num_sync, var.dtype)
  assigns = tuple(
      assign_on_device(v.device, v, tensor) for v in var.values)
  return control_flow_ops.group(assigns)
126
127
def in_replica_update_context():
  """Indicates if you are in an UpdateContext when running in a replica fn."""
  return distribute_lib.get_update_replica_id() is not None
132
133
def on_write_assign(var, value, use_locking=False, name=None, read_value=True):
  """Assigns `value` to `var` through its distributed `_update` machinery."""

  def _assign_fn(v, *args, **kwargs):
    return v.assign(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_assign_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)
142
143
def on_write_assign_add(var, value, use_locking=False, name=None,
                        read_value=True):
  """Adds `value` to `var` through its distributed `_update` machinery."""

  def _assign_add_fn(v, *args, **kwargs):
    return v.assign_add(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_assign_add_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)
153
154
def on_write_assign_sub(var, value, use_locking=False, name=None,
                        read_value=True):
  """Subtracts `value` from `var` through its distributed `_update` machinery."""

  def _assign_sub_fn(v, *args, **kwargs):
    return v.assign_sub(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_assign_sub_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)
164
165
def assign_on_each_device(var, assign_func, value, read_value):
  """Update the variable on each replica with the given assign_func and value."""
  packed = var._packed_variable  # pylint: disable=protected-access
  if packed is None:
    per_device = tuple(
        assign_func(v.device, v, value) for v in var._values)  # pylint: disable=protected-access
  else:
    # Packed variable: route every device's assignment through the packed
    # handle.
    per_device = tuple(
        assign_func(d, packed, value) for d in var._devices)  # pylint: disable=protected-access
  update = control_flow_ops.group(per_device)
  if not read_value:
    return update
  # Make the read observe the update by sequencing it after the assignments.
  with ops.control_dependencies([update] if update else []):
    return var.read_value()
179
180
def on_read_assign_sub_cross_replica(var, value, read_value=True):
  """Applies assign_sub to each component of an ON_READ `var`."""
  with ds_context.enter_or_assert_strategy(var.distribute_strategy):
    if not ds_context.in_cross_replica_context():
      # Same as before: no-op (implicit None) outside cross-replica context.
      return None
    if var.aggregation == vs.VariableAggregation.SUM:
      raise ValueError(
          "SyncOnReadVariable does not support `assign_sub` in "
          "cross-replica context when aggregation is set to "
          "`tf.VariableAggregation.SUM`.")
    return assign_on_each_device(var, assign_sub_on_device, value, read_value)
191
192
def on_read_assign_add_cross_replica(var, value, read_value=True):
  """Applies assign_add to each component of an ON_READ `var`."""
  with ds_context.enter_or_assert_strategy(var.distribute_strategy):
    if not ds_context.in_cross_replica_context():
      # Same as before: no-op (implicit None) outside cross-replica context.
      return None
    if var.aggregation == vs.VariableAggregation.SUM:
      raise ValueError(
          "SyncOnReadVariable does not support `assign_add` in "
          "cross-replica context when aggregation is set to "
          "`tf.VariableAggregation.SUM`.")
    return assign_on_each_device(var, assign_add_on_device, value, read_value)
203
204
def on_read_assign_cross_replica(var, value, read_value=True):
  """Return the value of the variable in cross replica context."""
  with ds_context.enter_or_assert_strategy(var.distribute_strategy):
    if not ds_context.in_cross_replica_context():
      return None
    tensor = value
    if var.aggregation == vs.VariableAggregation.SUM:
      # To preserve the sum across save and restore, divide the total across
      # all devices when restoring a variable that was summed when saving.
      strategy = var._distribute_strategy  # pylint: disable=protected-access
      tensor = math_ops.cast(tensor / strategy.num_replicas_in_sync, var.dtype)
    return assign_on_each_device(var, assign_on_device, tensor, read_value)
219
220
def scatter_sub(var, sparse_delta, use_locking=False, name=None):
  """Applies `scatter_sub` to `var` via its distributed `_update` machinery."""

  def _scatter_sub_fn(v, *args, **kwargs):
    return v.scatter_sub(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_sub_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
228
229
def scatter_add(var, sparse_delta, use_locking=False, name=None):
  """Applies `scatter_add` to `var` via its distributed `_update` machinery."""

  def _scatter_add_fn(v, *args, **kwargs):
    return v.scatter_add(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_add_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
237
238
def scatter_mul(var, sparse_delta, use_locking=False, name=None):
  """Applies `scatter_mul` to `var` via its distributed `_update` machinery."""

  def _scatter_mul_fn(v, *args, **kwargs):
    return v.scatter_mul(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_mul_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
246
247
def scatter_div(var, sparse_delta, use_locking=False, name=None):
  """Applies `scatter_div` to `var` via its distributed `_update` machinery."""

  def _scatter_div_fn(v, *args, **kwargs):
    return v.scatter_div(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_div_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
255
256
def scatter_min(var, sparse_delta, use_locking=False, name=None):
  """Applies `scatter_min` to `var` via its distributed `_update` machinery."""

  def _scatter_min_fn(v, *args, **kwargs):
    return v.scatter_min(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_min_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
264
265
def scatter_max(var, sparse_delta, use_locking=False, name=None):
  """Applies `scatter_max` to `var` via its distributed `_update` machinery."""

  def _scatter_max_fn(v, *args, **kwargs):
    return v.scatter_max(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_max_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
273
274
def scatter_update(var, sparse_delta, use_locking=False, name=None):
  """Applies `scatter_update` to `var` via its distributed `_update` machinery."""

  def _scatter_update_fn(v, *args, **kwargs):
    return v.scatter_update(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_update_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
282
283
def get_current_replica_id_as_int():
  """Returns the current replica ID as an integer, or `None`."""
  replica_context = ds_context.get_replica_context()
  if not replica_context:
    # Not in a replica context: fall back to the update-replica id, which may
    # also be None.
    return distribute_lib.get_update_replica_id()
  replica_id = replica_context._replica_id  # pylint: disable=protected-access
  if isinstance(replica_id, int):
    return replica_id
  # The replica id may be a tensor; extract its static value.
  return tensor_util.constant_value(replica_id)
294
295
def assign_on_device(device, variable, tensor):
  """Runs `variable.assign(tensor)` with the op placed on `device`."""
  with ops.device(device):
    return variable.assign(tensor)
299
300
def assign_add_on_device(device, variable, tensor):
  """Runs `variable.assign_add(tensor)` with the op placed on `device`."""
  with ops.device(device):
    return variable.assign_add(tensor)
304
305
def assign_sub_on_device(device, variable, tensor):
  """Runs `variable.assign_sub(tensor)` with the op placed on `device`."""
  with ops.device(device):
    return variable.assign_sub(tensor)
309
310
def assert_replica_context(strategy):
  """Raises unless called from a replica context belonging to `strategy`.

  Args:
    strategy: the `tf.distribute.Strategy` the current replica context is
      expected to belong to.

  Raises:
    RuntimeError: if there is no current replica context, or the current
      replica context belongs to a different strategy.
  """
  replica_context = ds_context.get_replica_context()
  if not replica_context:
    raise RuntimeError(
        "Replica-local variables may only be assigned in a replica context.")
  # Previously this branch raised the same (copy-pasted) message as above,
  # hiding the real cause; report the strategy mismatch explicitly.
  if replica_context.strategy is not strategy:
    raise RuntimeError(
        "Replica-local variables may only be assigned in a replica context "
        "of the same distribution strategy they were created under.")
319
320
def apply_aggregation(strategy, value, aggregation, destinations):
  """Aggregates `value` across replicas and places it on `destinations`."""
  if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
    # No reduction needed: broadcast the first replica's value everywhere.
    first = strategy.experimental_local_results(value)[0]
    return strategy.extended.broadcast_to(first, destinations=destinations)
  reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation)
  return strategy.extended.reduce_to(reduce_op, value, destinations)
328
329
# Error template for updating a synchronized variable in replica context
# without an aggregation method; `{variable_type}` is filled in by callers.
# NOTE: the string fragments previously concatenated without separating
# spaces/punctuation ("tf.Variable(..)."+"e.g.", "SUM)`"+"`tf...", etc.),
# producing garbled run-together sentences in the rendered error.
aggregation_error_msg = (
    "You must specify an aggregation method to update a "
    "{variable_type} in Replica Context. You can do so by passing "
    "an explicit value for argument `aggregation` to tf.Variable(..), "
    "e.g. `tf.Variable(..., aggregation=tf.VariableAggregation.SUM)`. "
    "`tf.VariableAggregation` lists the possible aggregation methods. "
    "This is required because {variable_type} should always be "
    "kept in sync. When updating them or assigning to them in a "
    "replica context, we automatically try to aggregate the values "
    "before updating the variable. For this aggregation, we need to "
    "know the aggregation method. "
    "Another alternative is to not try to update such "
    "{variable_type} in replica context, but in cross replica "
    "context. You can enter cross replica context by calling "
    "`tf.distribute.get_replica_context().merge_call(merge_fn, ..)`. "
    "Inside `merge_fn`, you can then update the {variable_type} "
    "using `tf.distribute.StrategyExtended.update()`.")
347
348
# Error template raised when a scatter_* op is attempted on a distributed
# variable whose aggregation mode does not support it; callers fill in
# `{op_name}` and `{aggregation}`.
scatter_error_msg = ("{op_name} is only supported for mirrored "
                     "variable (variable created within certain "
                     "`tf.distribute.Strategy` scope) with NONE or "
                     "`ONLY_FIRST_REPLICA` aggregation, got: {aggregation}.")
353
354
def is_saving_non_distributed():
  """Returns whether we're saving a non-distributed version of the model.

  It returns True iff we are in a save context and the save options'
  `experimental_variable_policy` is anything other than
  `EXPAND_DISTRIBUTED_VARIABLES` (e.g. `NONE`), i.e. component variables
  are not being written out individually.

  Returns:
    A boolean.
  """
  if not save_context.in_save_context():
    return False
  options = save_context.get_save_options()
  return (options.experimental_variable_policy !=
          save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES)
370
371
def mark_as_unsaveable():
  """Marks the function as unsaveable if not inside save context."""
  # When a tf.function touching distributed variables is traced outside of a
  # save context, the resulting graph cannot be saved later; record that on
  # the default graph with an actionable message for the user.
  if ops.inside_function() and not save_context.in_save_context():
    ops.get_default_graph().mark_as_unsaveable("""
ConcreteFunction that uses distributed variables in certain way cannot be saved.
If you're saving with

tf.saved_model.save(..., signatures=f.get_concrete_function())

do

@tf.function(input_signature=...)
def f_with_input_signature():
  ...

tf.saved_model.save(..., signatures=f_with_input_signature)`

instead.""")
390