xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/distribute_utils.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Class implementing utilities used by tf.distribute.Strategy."""
16
from collections import abc
import contextlib
import threading

from tensorflow.python.distribute import tpu_values as tpu_values_lib
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import nest
33
34
def regroup(values, wrap_class=values_lib.PerReplica, always_wrap=False):
  """Makes a nest per-replica into a nest of PerReplica/Mirrored values.

  Recursively mirrors the structure of `values[0]` (lists, tuples,
  namedtuples, mappings), zipping the per-replica leaves together and wrapping
  them in `wrap_class`.

  Args:
    values: Values to regroup, one structure per replica.
    wrap_class: Class that leaf `values` will be wrapped in.
    always_wrap: Always wrap the `values` in `wrap_class` even if the values
        are the same except for DistributeVariable.

  Returns:
    Wrapped `values`.
  """
  v0 = values[0]

  if isinstance(v0, list):
    for v in values[1:]:
      assert isinstance(v, list)
      assert len(v) == len(v0), ("Values to regroup had different lengths: "
                                 f"len(v) == {len(v)}, len(v0) == {len(v0)}, "
                                 f"v: {v}, v0: {v0}")
    return [
        regroup(tuple(v[i] for v in values), wrap_class, always_wrap)
        for i in range(len(v0))
    ]

  if isinstance(v0, tuple):
    for v in values[1:]:
      assert isinstance(v, tuple)
      assert len(v) == len(v0), ("Values to regroup had different lengths: "
                                 f"len(v) == {len(v)}, len(v0) == {len(v0)}, "
                                 f"v: {v}, v0: {v0}")
    regrouped_tuple = tuple(
        regroup(tuple(v[i] for v in values), wrap_class, always_wrap)
        for i in range(len(v0)))
    if hasattr(v0, "_fields"):
      # This tuple is in fact a namedtuple! Create a new namedtuple instance
      # and initialize it with the regrouped values:
      assert hasattr(v0, "_make")
      return v0._make(regrouped_tuple)
    else:
      return regrouped_tuple

  if isinstance(v0, abc.Mapping):
    v0keys = v0.keys()
    for v in values[1:]:
      assert isinstance(v, abc.Mapping), f"v[0]: {v0!r}  v[i]: {v!r}"
      assert set(v.keys()) == set(v0keys), (
          f"v[0].keys: {set(v0keys)}  v[i].keys: {set(v.keys())}")
    # Use the actual type in case it is a class inherited from a dict.
    return type(v0)({
        key: regroup(tuple(v[key] for v in values),
                     wrap_class, always_wrap)
        for key in v0keys
    })

  # If exactly the same object across all devices, return it unwrapped.
  same_id = all(v is v0 for v in values[1:])
  # Consider three cases where same_id is true:
  # * If v0 is a DistributedVariable (a MirroredVariable or
  #   SyncOnReadVariable, and same_id means it is the same across all
  #   devices), we want to return it. We check DistributedVariable
  #   specifically since it can look like it has a
  #   _distributed_container member since its members do.
  if same_id and isinstance(v0, values_lib.DistributedVariable):
    return v0
  # * If v0 is a member of a distributed variable, in which case
  #   hasattr(v0, "_distributed_container") is true, we want to
  #   return the DistributedVariable that contains it using the
  #   _distributed_container logic below. This case can trigger
  #   same_id when there is only one device.
  # * In any other situation, same_id means we return v0 unless `always_wrap` is
  #   true.
  if same_id and not always_wrap and not hasattr(v0, "_distributed_container"):
    return v0

  # Detect the case where each device has a parallel component of the
  # same MirroredVariable (or SyncOnReadVariable). In this case we
  # want to return the containing MirroredVariable, after a bunch of
  # sanity checking. In particular, each component should have the
  # same container, and the devices of the variables should match the
  # keys of the per-replica dictionary.
  if hasattr(v0, "_distributed_container"):
    # pylint: disable=protected-access
    assert not isinstance(v0, values_lib.MirroredVariable), (
        "ids = %s, values = %s" % ([id(v) for v in values], values))
    distributed_container = v0._distributed_container()
    assert distributed_container is not None
    for v in values[1:]:
      assert distributed_container is v._distributed_container()
    return distributed_container
  # pylint: enable=protected-access

  return wrap_class(values)
130
131
def select_replica(replica_id, structured):
  """Specialize a nest of regular & per-replica values for one replica."""

  def _select(value):
    # Only non-variable `DistributedValues` are sliced per replica:
    # `DistributedVariable` can be handled directly in the replica context,
    # and regular values pass through untouched.
    if (isinstance(value, values_lib.DistributedValues) and
        not isinstance(value, values_lib.DistributedVariable)):
      return value.values[replica_id]
    return value

  return nest.map_structure(_select, structured)
146
147
def select_replica_mirrored(replica_id, structured):
  """Specialize a nest of regular & mirrored values for one replica.

  Like `select_replica`, but first verifies every distributed value in
  `structured` is mirrored (raising `TypeError` via `assert_mirrored` if not).
  """
  assert_mirrored(structured)
  return select_replica(replica_id, structured)
152
153
def assert_mirrored(structured):
  """Raises if the structured is not composed of mirrored or regular values."""

  def _check(value):
    # Regular (non-distributed) values are always acceptable.
    if not isinstance(value, values_lib.DistributedValues):
      return
    if not is_mirrored(value):
      raise TypeError(
          "Expected value to be mirrored across replicas: %s in %s." %
          (value, structured))

  nest.map_structure(_check, structured)
164
165
def update_regroup(extended, updates, group):
  """Regroup for an update, with dependencies to ensure all updates execute."""
  if not group:
    regrouped = regroup(updates, values_lib.Mirrored)
    return nest.map_structure(extended._local_results, regrouped)  # pylint: disable=protected-access

  def _make_grouped_mirrored(values):
    """Convert per-replica list `values` into Mirrored type with grouping."""
    if len(values) == 1:
      return values_lib.Mirrored(values)

    # Make sure we run all updates. Without this, something like
    # session.run(extended.update(...)) may only update one replica.
    barrier = control_flow_ops.group(values)

    # If values is just ops, the grouping is enough. Everything in values
    # should have the same type, since we expect every replica to be performing
    # the same computation.
    if any(not tensor_util.is_tf_type(v) for v in values):
      return barrier

    # Otherwise we need tensors with the same values as `values`, but
    # that have a dependency on `barrier`.
    def _tie(value):
      with ops.device(value.device), ops.control_dependencies([barrier]):
        return array_ops.identity(value)

    return values_lib.Mirrored([_tie(v) for v in values])

  return regroup(updates, _make_grouped_mirrored)
197
198
def value_container(val):
  """Returns the container that this per-replica `value` belongs to.

  Args:
    val: A value returned by `call_for_each_replica()` or a variable created in
      `scope()`.

  Returns:
    A container that `value` belongs to.
    If value does not belong to any container (including the case of
    container having been destroyed), returns the value itself.
  """
  # A `DistributedVariable` exposes `_distributed_container` too (through its
  # members), but it is already the container, so it must not be unwrapped.
  is_member = (hasattr(val, "_distributed_container") and
               not isinstance(val, values_lib.DistributedVariable))
  if is_member:
    container = val._distributed_container()  # pylint: disable=protected-access
    if container is not None:
      return container
  return val
219
220
def is_distributed_variable(v):
  """Determine if a variable is ds variable or TPU mirrored variable."""
  # Distributed variable classes advertise themselves via this attribute;
  # anything lacking it is treated as a plain variable.
  try:
    return v.is_distributed_variable
  except AttributeError:
    return False
224
225
def is_distributed_table(v):
  """Determine if an object is a DistributedTable."""
  # DistributedTable instances carry this marker attribute; everything else
  # defaults to False.
  try:
    return v.is_distributed_table
  except AttributeError:
    return False
229
230
def _validate_colocate_extended(v, extended):
  """Raises unless `v` was created under the strategy owning `extended`."""
  owning_strategy = v._distribute_strategy  # pylint: disable=protected-access
  if owning_strategy.extended is not extended:
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not %s created in scope: %s" %
        (v, owning_strategy))
238
239
def validate_colocate_distributed_variable(v, extended):
  """Checks `v` is a `DistributedVariable` created under `extended`'s scope.

  Raises:
    ValueError: if `v` is not a `DistributedVariable`, or was created under a
      different strategy.
  """
  if not isinstance(v, values_lib.DistributedVariable):
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not: %r" % (v,))
  _validate_colocate_extended(v, extended)
246
247
def validate_colocate(v, extended):
  """Checks `v` is strategy-aware and was created under `extended`'s scope.

  Raises:
    ValueError: if `v` has no `_distribute_strategy`, or was created under a
      different strategy.
  """
  if not hasattr(v, "_distribute_strategy"):
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not: %r" % (v,))
  _validate_colocate_extended(v, extended)
254
255
256# Variable creation function for sync strategies.
def _validate_synchronization(kwargs):
  """Validate that given synchronization value is valid."""
  synchronization = kwargs.get("synchronization",
                               vs.VariableSynchronization.AUTO)
  # NONE is explicitly unsupported under a sync distribution strategy.
  if synchronization == vs.VariableSynchronization.NONE:
    raise ValueError(
        "`NONE` variable synchronization mode is not supported with "
        "tf.distribute strategy. Please change the `synchronization` for "
        "variable: " + str(kwargs["name"]))
  allowed = (vs.VariableSynchronization.ON_READ,
             vs.VariableSynchronization.ON_WRITE,
             vs.VariableSynchronization.AUTO)
  if synchronization not in allowed:
    raise ValueError(
        "Invalid variable synchronization mode: %s for variable: %s" %
        (synchronization, kwargs["name"]))
  # AUTO resolves to ON_WRITE under a sync strategy.
  if synchronization == vs.VariableSynchronization.AUTO:
    return vs.VariableSynchronization.ON_WRITE
  return synchronization
275
276
def _validate_aggregation(kwargs):
  """Validates and returns the `aggregation` entry of `kwargs`."""
  aggregation = kwargs.get("aggregation", vs.VariableAggregation.NONE)
  valid_modes = (vs.VariableAggregation.NONE,
                 vs.VariableAggregation.SUM,
                 vs.VariableAggregation.MEAN,
                 vs.VariableAggregation.ONLY_FIRST_REPLICA)
  if aggregation not in valid_modes:
    raise ValueError("Invalid variable aggregation mode: %s for variable: %s" %
                     (aggregation, kwargs["name"]))
  return aggregation
287
288
def create_mirrored_variable(strategy, real_mirrored_creator, class_mapping,
                             policy_mapping, **kwargs):
  """Create distributed variables with given synchronization and aggregation.

  Args:
    strategy: the `tf.distribute.Strategy` under which the variable is being
      created; its `extended._use_var_policy` flag selects between the policy
      and non-policy wrapper classes.
    real_mirrored_creator: callable invoked with `**kwargs` that returns the
      list of per-replica component variables.
    class_mapping: maps a `VariableSynchronization` value (or the
      "VariableClass" key, when policies are in use) to the wrapper class to
      instantiate.
    policy_mapping: maps a `VariableSynchronization` value to a variable
      policy class; only consulted when the strategy uses variable policies.
    **kwargs: keyword arguments forwarded to `real_mirrored_creator`;
      `collections`, `synchronization`, `aggregation`, `caching_device`,
      `trainable` and `name` are inspected and/or adjusted here.

  Returns:
    The wrapped distributed variable, already added to the requested graph
    collections when executing in graph mode.
  """
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  # Keep the component variables out of any collection; only the wrapper is
  # registered (see below).
  kwargs["collections"] = []

  synchronization = _validate_synchronization(kwargs)
  # Update synchronization in kwargs in case it's AUTO, which is converted to
  # ON_WRITE.
  kwargs["synchronization"] = synchronization
  aggregation = _validate_aggregation(kwargs)
  use_var_policy = getattr(strategy.extended, "_use_var_policy", False)

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    value_list = real_mirrored_creator(**kwargs)
    # MirroredVariable is recreated during saved_model loading, and its
    # component variables (value_list) will have None initializer. We
    # set their initializers to no_op so that consumer like
    # `global_variables_initializer` wouldn't complain, as it groups all
    # variables' initializers thus all variables have to have initializers.
    for v in value_list:
      # pylint:disable=protected-access
      if hasattr(v, "_initializer_op") and v._initializer_op is None:
        v._initializer_op = control_flow_ops.no_op()
      # pylint:enable=protected-access
    if use_var_policy:
      # Policy path: a generic wrapper class plus a synchronization-specific
      # policy object.
      var_policy_cls = policy_mapping.get(synchronization)
      var_policy = var_policy_cls(aggregation=aggregation)
      var_cls = class_mapping.get("VariableClass")
      result = var_cls(strategy, value_list, aggregation, var_policy=var_policy)
    else:
      # Non-policy path: the wrapper class itself encodes the synchronization.
      var_cls = class_mapping.get(synchronization)
      result = var_cls(strategy, value_list, aggregation)

  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for value in value_list:
        for i, trainable_variable in enumerate(l):
          if value is trainable_variable:
            del l[i]
            break

    g.add_to_collections(var_collections, result)
  elif ops.GraphKeys.GLOBAL_STEP in var_collections:
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

  return result
357
358
359# Utility functions
360# Return True if the Value is Mirrored or the Variable is replicated and kept in
361# sync.
def is_mirrored(val):
  """Returns True if `val` is Mirrored or a variable replicated in sync."""
  if isinstance(val, values_lib.DistributedVariable):
    policy = val._policy  # pylint: disable=protected-access
    if policy:
      return policy._is_mirrored()  # pylint: disable=protected-access
  return isinstance(val, values_lib.Mirrored)
367
368
def is_sync_on_read(val):
  """Returns True if `val` is aggregated on read rather than mirrored."""
  if isinstance(val, values_lib.DistributedVariable):
    policy = val._policy  # pylint: disable=protected-access
    if policy:
      return not policy._is_mirrored()  # pylint: disable=protected-access
  return not isinstance(val, values_lib.Mirrored)
374
375
class CachingScopeLocal(threading.local):
  """Class for maintaining thread local state for caching scope."""

  def __init__(self):
    super().__init__()
    # The scope is active on this thread whenever strictly more scopes have
    # been entered than exited.
    self.new_cache_scope_count = 0
    self.cache_scope_exited_count = 0

  def enter_scope(self):
    """Records entry into a caching scope on the current thread."""
    self.new_cache_scope_count += 1

  def exit_scope(self):
    """Records exit from a caching scope on the current thread."""
    self.cache_scope_exited_count += 1

  def in_caching_scope(self):
    """Returns True iff the current thread is inside a caching scope."""
    return self.new_cache_scope_count > self.cache_scope_exited_count
392
393
# Module-level singleton (per-thread via threading.local) used by
# cache_variable_reads() to track scope nesting.
caching_scope_local = CachingScopeLocal()
395
396
@contextlib.contextmanager
def cache_variable_reads():
  """Scope for caching variable reads for AggregatingVariable.

  The variable reads for AggregatingVariable inside this scope are cached. i.e.
  the first read of variable reads the value from possibly remote handle, but
  subsequent reads are returned using local cached value.

  For example:
  strategy = ParameterServerStrategy...
  with strategy.scope():
    # Variable v is of AggregatingVariable type with actual variable residing
    # on PS.
    v = tf.Variable(1.0)

  with distribute_utils.cache_variable_reads():
    v.read_value()  # Reads value 1.0
    v.assign(constant_op.constant(5.0))  # v changes to 5.0
    t1 = v.read_value()
    t2 = v.read_value()  # Both t1 & t2 return cached value 1.0 from local CPU.

  Notes about cache_variable_reads scope:
  1. Nesting of scope cache_variable_reads() is not supported
  2. And when caching scope is enabled, the thread enabling the cache and
    mirrored_run._MirroredReplicaThread threads spawned from it will have
    caching enabled.

  Yields:
    A context for caching variables.

  Raises:
    ValueError: if this scope is nested inside another caching scope.
  """
  if caching_scope_local.in_caching_scope():
    # There is nested cache scope, which is not supported.
    raise ValueError("cache_variable_reads scope cannot be nested")
  # Enter the scope only after the nesting check, and only then arm the
  # balancing exit_scope(). Previously the check lived inside the try block,
  # so a rejected nested scope still ran exit_scope() in `finally`, leaving
  # the counters unbalanced and silently deactivating the outer scope.
  caching_scope_local.enter_scope()
  try:
    yield
  finally:
    caching_scope_local.exit_scope()
436
437
# The following mapping indicates the policy that you must use for a given
# variable `synchronization` and `aggregation` pair.
# OnWritePolicy is used for:
# (synchronization=Auto, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# (synchronization=ON_WRITE, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# OnReadPolicy is used for:
# (synchronization=ON_READ, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.ON_WRITE: values_lib.OnWritePolicy,
    vs.VariableSynchronization.ON_READ: values_lib.OnReadPolicy,
}

# Maps a synchronization mode (or the generic "VariableClass" key, used on the
# policy path) to the wrapper class instantiated by create_mirrored_variable().
VARIABLE_CLASS_MAPPING = {
    "VariableClass": values_lib.DistributedVariable,
    vs.VariableSynchronization.ON_WRITE: values_lib.MirroredVariable,
    vs.VariableSynchronization.ON_READ: values_lib.SyncOnReadVariable,
}

# TPU analogues of the two mappings above, for variables created under a TPU
# strategy.
TPU_VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUOnWritePolicy,
    vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUOnReadPolicy,
}

TPU_VARIABLE_CLASS_MAPPING = {
    "VariableClass": tpu_values_lib.TPUDistributedVariable,
    vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUMirroredVariable,
    vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUSyncOnReadVariable,
}
466