# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed values."""

import copy
import weakref

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.trackable import base as trackable
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.types import core
from tensorflow.python.types import distribute as ds_types
from tensorflow.python.types import trace


def _on_write_update_replica(var, update_fn, value, **kwargs):
  """Updates variables with ON_WRITE synchronization in replica context."""
  if var.aggregation == vs.VariableAggregation.NONE:
    return update_fn(var._get_on_device_or_primary(), value, **kwargs)  # pylint: disable=protected-access

  if not ds_context.get_strategy().extended._use_merge_call():  # pylint: disable=protected-access
    # Don't allow MEAN with a non-float dtype, since it may cause unexpected
    # precision loss. Python3 and NumPy automatically upcast integers to
    # float in division, but we should always preserve the type.
    if var.aggregation == vs.VariableAggregation.MEAN and (
        not var.dtype.is_floating) and tensor_util.is_tf_type(value):
      raise ValueError(
          "Cannot update non-float variables with "
          "tf.VariableAggregation.MEAN aggregation in replica context. "
          "Either change the variable dtype to float or update it in "
          "cross-replica context.")

    aggregated_value = apply_aggregation_replica_context(
        value, var.aggregation, var)
    values_util.mark_as_unsaveable()

    return ds_context.get_replica_context()._update(  # pylint: disable=protected-access
        var,
        update_fn,
        args=(aggregated_value,),
        kwargs=kwargs,
        group=True)

  else:

    def merge_fn(strategy, value, **kwargs):
      """Aggregate values and update all variables in cross replica context."""
      # Don't allow MEAN with a non-float dtype, since it may cause unexpected
      # precision loss. Python3 and NumPy automatically upcast integers to
      # float in division, but we should always preserve the type.
      #
      # Note that to be backward compatible we allow the case when the value
      # is *always* the same on each replica, i.e. the value is not a
      # PerReplica. Refer to regroup() to see how values are grouped.
      if var.aggregation == vs.VariableAggregation.MEAN and (
          not var.dtype.is_floating) and isinstance(value, PerReplica):
        raise ValueError(
            "Cannot update non-float variables with "
            "tf.VariableAggregation.MEAN aggregation in replica context. "
            "Either change the variable dtype to float or update it in "
            "cross-replica context.")

      assert strategy == var.distribute_strategy
      v = values_util.apply_aggregation(strategy, value, var.aggregation, var)
      return var._update_cross_replica(update_fn, v, **kwargs)  # pylint: disable=protected-access

    return ds_context.get_replica_context().merge_call(
        merge_fn, args=(value,), kwargs=kwargs)
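

# NOTE: For illustration only (a sketch, not exercised by this module): with a
# two-replica strategy, an ON_WRITE variable created with SUM aggregation has
# replica-context assignments aggregated before they are applied, e.g.
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#   with strategy.scope():
#     v = tf.Variable(
#         0.0, synchronization=tf.VariableSynchronization.ON_WRITE,
#         aggregation=tf.VariableAggregation.SUM)
#   strategy.run(lambda: v.assign(1.0))  # each replica contributes 1.0
#   # v now holds 2.0 on every replica.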


def apply_aggregation_replica_context(value, aggregation, destinations):
  """Aggregate `value` to `destinations` as specified by `aggregation`."""
  # If `value` is a Python literal, return it without aggregation.
  if isinstance(value, DistributedValues):
    raise TypeError(
        "Cannot use DistributedValues to update variables in replica context.")
  if not tensor_util.is_tf_type(value):
    return value

  if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
    # Switch to cross-replica context to broadcast.
    def merge_fn(strategy, value):
      return strategy.extended.broadcast_to(
          strategy.experimental_local_results(value)[0],
          destinations=destinations)

    return ds_context.get_replica_context().merge_call(merge_fn, args=(value,))

  else:
    reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation)
    aggregated_value = ds_context.get_strategy(  # pylint: disable=protected-access
    ).extended._replica_ctx_all_reduce(reduce_op, value)
    return aggregated_value


class DistributedValues(ds_types.DistributedValues):
  """Base class for representing distributed values."""

  def __init__(self, values):
    """Should only be called by subclass __init__."""
    self._values = tuple(values)

  def _get(self):
    """Returns the value for the current device or raises a ValueError."""
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      return self._get_cross_replica()
    else:
      return self._values[replica_id]

  def _get_cross_replica(self):
    raise NotImplementedError(
        "DistributedValues._get_cross_replica should be implemented by "
        "sub-classes which support cross-replica accesses.")

  def _get_on_device_or_primary(self):
    """Returns the value for the current replica or device, else `_primary`."""
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      # Try to find a value on the current device.
      current_device = device_util.canonicalize(device_util.current())
      for value in self._values:
        if device_util.canonicalize(value.device) == current_device:
          return value
      return self._primary
    else:
      return self._values[replica_id]

  @property
  def _primary(self):
    """Returns a representative component."""
    return self._values[0]

  @property
  def _devices(self):
    return tuple(v.device for v in self._values)

  def __str__(self):
    debug_str = ",\n".join(
        "  %d: %s" % (i, v) for i, v in enumerate(self._values))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str)

  def __repr__(self):
    debug_repr = ",\n".join(
        "  %d: %r" % (i, v) for i, v in enumerate(self._values))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr)
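

# For illustration only (a sketch, not used by this module): DistributedValues
# are normally produced by `tf.distribute` APIs rather than constructed
# directly, e.g.
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#   per_replica = strategy.run(
#       lambda: tf.distribute.get_replica_context().replica_id_in_sync_group)
#   # `per_replica` holds one value per replica; inside `strategy.run` each
#   # replica transparently sees only its own component.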


# NOTE(josh11b,apassos): It would be great if we could inspect the values this was
# initialized with and use that to generate the overloaded operators here.
# Unfortunately, Python's rules for special methods don't allow this, see
# https://docs.python.org/3/reference/datamodel.html#special-method-names
# "if a class defines a method named __getitem__(), and x is an instance of
# this class, then x[i] is roughly equivalent to type(x).__getitem__(x, i)."
# In particular, these special methods don't go through __getattr__, and
# it will only use those methods if they are defined in the class, not the
# object.
class DistributedDelegate(DistributedValues):
  """A map from device to values; acts as the same type as the values."""

  def __getattr__(self, name):
    # The '_use_resource_variables' attribute and attributes starting with
    # '_self' are used for restoring the saved_model proto, and
    # '_attribute_sentinel' is used for Layer tracking. At the point these
    # attributes are queried, the variable has not been initialized. Thus it
    # should not query those of the underlying components.
    if name.startswith("_self_") or name in ("_use_resource_variables",
                                             "_attribute_sentinel",
                                             "_distributed_container"):
      return super(DistributedDelegate, self).__getattr__(name)

    # This allows copy.copy(DistributedDelegate). When copying an object,
    # copy.copy doesn't invoke its __init__ method, instead it makes a new
    # empty object, then copies the attributes over. copy.copy looks for
    # attributes like "__getstate__" in case the object implements its custom
    # copying. Since DistributedDelegate doesn't have those attributes defined,
    # __getattr__ will be invoked, which tries to access "_values" attributes,
    # but that doesn't exist either because this is an empty object, and again
    # __getattr__ is invoked, leading to an infinite recursion.
    if name == "_values":
      raise AttributeError()

    # TODO(priyag): This needs to be made robust against pitfalls from mixed
    # use of __getattr__ and @property. See b/120402273.
    return getattr(self._get(), name)

  @property
  def values(self):
    """Returns the per replica values."""
    return self._values

  def _get_as_operand(self):
    """Returns the value for operations for the current device.

    Some implementations, e.g. `TPUMirroredVariable`, are not able to return the
    value type within a replica context. They can, however, return a value that
    can be used by the operations below.
    """
    return self._get()

  # pylint: disable=multiple-statements
  def __add__(self, o):
    return self._get_as_operand() + o

  def __radd__(self, o):
    return o + self._get_as_operand()

  def __sub__(self, o):
    return self._get_as_operand() - o

  def __rsub__(self, o):
    return o - self._get_as_operand()

  def __mul__(self, o):
    return self._get_as_operand() * o

  def __rmul__(self, o):
    return o * self._get_as_operand()

  def __truediv__(self, o):
    return self._get_as_operand() / o

  def __rtruediv__(self, o):
    return o / self._get_as_operand()

  def __floordiv__(self, o):
    return self._get_as_operand() // o

  def __rfloordiv__(self, o):
    return o // self._get_as_operand()

  def __mod__(self, o):
    return self._get_as_operand() % o

  def __rmod__(self, o):
    return o % self._get_as_operand()

  def __lt__(self, o):
    return self._get_as_operand() < o

  def __le__(self, o):
    return self._get_as_operand() <= o

  def __gt__(self, o):
    return self._get_as_operand() > o

  def __ge__(self, o):
    return self._get_as_operand() >= o

  def __and__(self, o):
    return self._get_as_operand() & o

  def __rand__(self, o):
    return o & self._get_as_operand()

  def __or__(self, o):
    return self._get_as_operand() | o

  def __ror__(self, o):
    return o | self._get_as_operand()

  def __xor__(self, o):
    return self._get_as_operand() ^ o

  def __rxor__(self, o):
    return o ^ self._get_as_operand()

  def __getitem__(self, o):
    return self._get_as_operand()[o]

  def __pow__(self, o, modulo=None):
    return pow(self._get_as_operand(), o, modulo)

  def __rpow__(self, o):
    return pow(o, self._get_as_operand())

  def __invert__(self):
    return ~self._get_as_operand()

  def __neg__(self):
    return -self._get_as_operand()

  def __abs__(self):
    return abs(self._get_as_operand())

  def __div__(self, o):
    try:
      return self._get_as_operand().__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self._get_as_operand().__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self._get_as_operand().__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self._get_as_operand().__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  # TODO(josh11b): Even more operator overloads.


class PerReplica(DistributedValues, composite_tensor.CompositeTensor,
                 ds_types.PerReplica):
  """Holds a map from replica to unsynchronized values."""

  @property
  def _type_spec(self):
    return PerReplicaSpec(
        *(type_spec.type_spec_from_value(v) for v in self._values))

  @property
  def values(self):
    """Returns the per replica values."""
    return self._values
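

# For illustration only (a sketch, not used by this module): a `PerReplica`
# simply wraps one value per replica, e.g.
#
#   pr = PerReplica([tf.constant(1.0), tf.constant(2.0)])
#   pr.values  # -> (<tf.Tensor 1.0>, <tf.Tensor 2.0>)
#
# Such objects are normally produced by `strategy.run` or by iterating a
# distributed dataset rather than constructed by hand.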


def _per_replica_to_tensor(var, dtype=None, name=None, as_ref=False):
  """Converts a `PerReplica` to a `Tensor`."""
  del name
  if dtype is not None and not dtype.is_compatible_with(var.dtype):
    raise ValueError(
        "Incompatible type conversion requested to type {!r} for variable "
        "of type {!r}".format(dtype.name, var.dtype.name))
  if as_ref:
    raise NotImplementedError(
        "PerReplica doesn't support being used as a reference.")
  if ds_context.in_cross_replica_context() or not ds_context.has_strategy():
    raise ValueError("It looks like you are using a PerReplica object while "
                     "not inside a replica context, which is not supported. "
                     "Try running your op or function inside a replica context "
                     "by using `strategy.run`")
  else:
    replica_id = values_util.get_current_replica_id_as_int()
    return var.values[replica_id]

# Register a conversion function to provide a useful error message when users
# try to use PerReplica values in the wrong contexts
ops.register_tensor_conversion_function(PerReplica, _per_replica_to_tensor)


class PerReplicaSpec(type_spec.TypeSpec):
  """Type specification for a `PerReplica`."""

  __slots__ = ["_value_specs"]

  value_type = property(lambda self: PerReplica)

  def __init__(self, *value_specs):
    self._value_specs = tuple(value_specs)

  def _serialize(self):
    return self._value_specs

  @property
  def _component_specs(self):
    return self._value_specs

  def _to_components(self, value):
    replica_context = ds_context.get_replica_context()
    if replica_context is not None and replica_context.num_replicas_in_sync > 1:
      raise ValueError(
          "Flattening a PerReplica to components is not supported in replica "
          "context.")
    return value._values  # pylint: disable=protected-access

  def _from_components(self, tensor_list):
    return PerReplica(tensor_list)
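

# For illustration only (a sketch): the `TypeSpec` of a `PerReplica` is built
# from the specs of its components, e.g. a `PerReplica` of two float vectors of
# shape [2] has the type spec
#
#   PerReplicaSpec(tf.TensorSpec([2], tf.float32),
#                  tf.TensorSpec([2], tf.float32))
#
# which is what allows `PerReplica` values to be flattened into components and
# rebuilt outside of a replica context.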


# Note that unlike PerReplica, Mirrored values inherit from
# DistributedDelegate and so can be used directly in cross-replica mode.
# TODO(tomhennigan) Should this extend CompositeTensor?
class Mirrored(DistributedDelegate, ds_types.Mirrored):
  """Holds a map from replica to values which are kept in sync."""

  def _get_cross_replica(self):
    return self._get_on_device_or_primary()

  def _as_graph_element(self):
    obj = self._get()
    conv_fn = getattr(obj, "_as_graph_element", None)
    if conv_fn and callable(conv_fn):
      return conv_fn()
    return obj
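

# For illustration only (a sketch): because `Mirrored` delegates to one of its
# identical components, it behaves like a single value even in cross-replica
# context, e.g.
#
#   m = Mirrored([tf.constant(3.0), tf.constant(3.0)])
#   m + 1.0  # -> 4.0, read from the component on the current device (or the
#            #    primary); a PerReplica cannot be used this way.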


class DistributedVarOp(object):
  """A class that looks like `tf.Operation`."""

  def __init__(self, name, graph, traceback, typ):
    self.name = name
    self.graph = graph
    self.traceback = traceback
    self.type = typ

  def __eq__(self, o):
    if not isinstance(o, self.__class__):
      raise NotImplementedError
    return (self.name == o.name and self.graph == o.graph and
            self.traceback == o.traceback and self.type == o.type)

  def __hash__(self):
    return hash((self.name, self.graph, tuple(self.traceback), self.type))


# TODO(b/209081027): Remove this once Variable is a CompositeTensor.
class DistributedVariableTraceType(trace.TraceType):
  """TraceType of DistributedVariable objects."""

  def __init__(self, distributed_variable):
    self.distributed_variable = distributed_variable
    self.components = (tuple(distributed_variable.shape.as_list()),
                       distributed_variable.dtype)

  def is_subtype_of(self, other):
    return self == other

  def most_specific_common_supertype(self, others):
    return self if all(self == other for other in others) else None

  def _placeholder_value(self):
    return self.distributed_variable

  def __hash__(self) -> int:
    return hash(self.components)

  def __eq__(self, other) -> bool:
    if not isinstance(other, DistributedVariableTraceType):
      return False

    return self.components == other.components
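

# For illustration only (a sketch): two distributed variables with the same
# shape and dtype produce equal trace types, so a `tf.function` traced with one
# can generally be reused with the other; a different shape or dtype yields an
# unequal trace type and triggers a retrace.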


class DistributedVariable(DistributedDelegate, variables_lib.Variable,
                          core.Tensor):
  """Holds a map from replica to variables."""

  def __init__(self, strategy, values, aggregation, var_policy=None):
    if (aggregation == variables_lib.VariableAggregation.MEAN and
        not values[0].dtype.is_floating):
      raise ValueError(
          "creating distributed tf.Variable with aggregation=MEAN and a "
          "non-floating dtype is not supported, please use a different "
          "aggregation or dtype")
    self._distribute_strategy = strategy
    self._aggregation = aggregation
    super(DistributedVariable, self).__init__(values)
    self._common_name = self._primary.name.split(":")[0]
    # Use a weakref to make it easy to map from the contained values
    # to the container without introducing a reference cycle.
    for v in values:
      v._distributed_container = weakref.ref(self)  # pylint: disable=protected-access

    # A packed variable is used to reduce the overhead of function execution.
    # For a DistributedVariable, only one variable handle is captured into a
    # function graph. It's only supported in eager mode.
    if ops.executing_eagerly_outside_functions() and getattr(
        strategy, "_enable_packed_variable_in_eager_mode", False):
      name = "%s/packed/" % self._common_name
      self._packed_var = packed.PackedDistributedVariable(values, name=name)
    else:
      self._packed_var = None

    # tf.keras keeps track of variables initialized using this attribute. When
    # tf.keras gets the default session, it initializes all uninitialized vars.
    # We need to make _keras_initialized a member of DistributedVariable because
    # without this it will use `__getattr__` which will delegate to a component
    # variable.
    self._keras_initialized = False
    # Typically, a `DistributedVariable`'s initializer is composed of the
    # initializers of the component variables. However, in some cases, such as
    # when restoring from a checkpoint, we may set the _initializer_op
    # property on the entire `DistributedVariable`.
    self._initializer_op = None
    # Set a VariablePolicy which decides how we replicate/aggregate the given
    # variable.
    self._policy = var_policy

  def __deepcopy__(self, memo):
    """Perform a deepcopy of the `DistributedVariable`.

    Unlike the deepcopy of a regular tf.Variable, this keeps the original
    strategy and devices of the `DistributedVariable`.  To avoid confusion
    with the behavior of deepcopy on a regular `Variable` (which does
    copy into new devices), we only allow a deepcopy of a `DistributedVariable`
    within its originating strategy scope.

    Args:
      memo: The memoization object for `deepcopy`.

    Returns:
      A deep copy of the current `DistributedVariable`.

    Raises:
      RuntimeError: If trying to deepcopy into a different strategy.
    """
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      new_values = []

      for value in self._values:
        with ops.device(value.device):
          new_values.append(copy.deepcopy(value, memo))

    copied_variable = type(self)(
        strategy=self._distribute_strategy,
        values=new_values,
        aggregation=self._aggregation,
        var_policy=copy.deepcopy(self._policy, memo))

    memo[id(self)] = copied_variable

    return copied_variable

  def _use_packed_variable(self):
    # Don't use the packed variable when under a SaveContext to avoid explicit
    # device placement on variable-consuming ops.
    return self._packed_var is not None and (
        not values_util.is_saving_non_distributed())

  def is_initialized(self, name=None):
    """Identifies if all the component variables are initialized.

    Args:
      name: Name of the final `logical_and` op.

    Returns:
      The op that evaluates to True or False depending on if all the
      component variables are initialized.
    """
    if values_util.is_saving_non_distributed():
      return self._primary.is_initialized()
    if self._use_packed_variable():
      return self._packed_var.is_initialized()
    result = self._primary.is_initialized()
    # We iterate through the list of values except the first (already handled
    # above) and the last one, so that the final `logical_and` op gets the name
    # passed by the user to `is_initialized`. For distributed variables, the
    # `is_initialized` op is a `logical_and` op.
    for v in self._values[1:-1]:
      result = math_ops.logical_and(result, v.is_initialized())
    result = math_ops.logical_and(
        result, self._values[-1].is_initialized(), name=name)
    return result

  @property
  def initializer(self):
    if values_util.is_saving_non_distributed():
      return self._primary.initializer
    if self._initializer_op:
      init_op = self._initializer_op
    else:
      # Return the grouped initializer ops of all the component variables of
      # the mirrored variable.
      init_op = control_flow_ops.group(
          tuple(v.initializer for v in self._values))
    return init_op

  def initialized_value(self):
    return self._get_on_device_or_primary().initialized_value()

  @property
  def initial_value(self):
    return self._get_on_device_or_primary().initial_value

  @property
  def constraint(self):
    return self._primary.constraint

  @property
  def graph(self):
    return self._primary.graph

  @property
  def _shared_name(self):
    return self._common_name

  @property
  def _unique_id(self):
    return self._primary._unique_id  # pylint: disable=protected-access

  @property
  def _graph_key(self):
    """Lets Optimizers know which graph this variable is from."""
    return self._primary._graph_key  # pylint: disable=protected-access

  @property
  def name(self):
    return self._primary.name

  @property
  def dtype(self):
    return self._primary.dtype

  @property
  def shape(self):
    return self._primary.shape

  @property
  def synchronization(self):
    return self._primary.synchronization

  @property
  def aggregation(self):
    return self._aggregation

  @property
  def _packed_variable(self):
    if self._use_packed_variable():
      return self._packed_var
    return None

  @property
  def handle(self):
    if values_util.is_saving_non_distributed():
      return self._primary.handle
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      raise ValueError(
          "DistributedVariable.handle is not available outside the replica "
          "context or a `tf.distribute.Strategy.update()` call.")
    else:
      if self._use_packed_variable():
        return self._packed_var.handle
      return self._values[replica_id].handle

  def eval(self, session=None):
    return self._get_on_device_or_primary().eval(session)

  @property
  def _save_slice_info(self):
    return self._primary._save_slice_info  # pylint: disable=protected-access

  def _get_save_slice_info(self):
    return self._primary._get_save_slice_info()  # pylint: disable=protected-access

  def _set_save_slice_info(self, save_slice_info):
    for v in self._values:
      v._set_save_slice_info(save_slice_info)  # pylint: disable=protected-access

  @property
  def device(self):
    return self._get_on_device_or_primary().device

  @property
  def trainable(self):
    return self._primary.trainable

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  def get_shape(self):
    return self._primary.get_shape()

  def to_proto(self, export_scope=None):
    return self._primary.to_proto(export_scope=export_scope)

  @property
  def op(self):
    if values_util.is_saving_non_distributed():
      return self._primary.op
    # We want cross-replica code that does some var.op.X calls
    # to work (even if the current device isn't in self._devices), but
    # other uses of var.op in a cross-replica context to fail.
    if ds_context.in_cross_replica_context():
      return DistributedVarOp(self._primary.op.name, self._primary.op.graph,
                              self._primary.op.traceback, self._primary.op.type)
    return self._get().op

  @property
  def _in_graph_mode(self):
    return self._primary._in_graph_mode  # pylint: disable=protected-access

  def _get_replica(self, replica_id):
    """Returns the value on a device with the given replica_id."""
    if self._use_packed_variable():
      return self._packed_var.on_device(self._devices[replica_id])
    return self._values[replica_id]

  def _get(self):
    """Returns the value for the current device or raises a ValueError."""
    if values_util.is_saving_non_distributed():
      return self._primary
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      return self._get_cross_replica()
    else:
      return self._get_replica(replica_id)

  def _get_on_device_or_primary(self):
    """Returns the value for the current replica or device, else `_primary`."""
    if values_util.is_saving_non_distributed():
      return self._primary
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      # Try to find a value on the current device.
      current_device = device_util.canonicalize(device_util.current())
      for i, value in enumerate(self._values):
        if device_util.canonicalize(value.device) == current_device:
          return self._get_replica(i)
      return self._get_replica(0)
    else:
      return self._get_replica(replica_id)

  def read_value(self):
    if values_util.is_saving_non_distributed():
      return self._primary.read_value()
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      return array_ops.identity(self._get())

  def value(self):
    if values_util.is_saving_non_distributed():
      return self._primary.value()
    if self._policy:
      return self._policy.value(self)
    return self._get_on_device_or_primary().value()

  def numpy(self):
    if context.executing_eagerly():
      return self.read_value().numpy()
    else:
      raise NotImplementedError("DistributedVariable.numpy() is only available "
                                "when eager execution is enabled.")

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign_sub(value, use_locking, name, read_value)
    if self._policy:
      return self._policy.assign_sub(
          self,
          value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    return values_util.on_write_assign_sub(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign_add(value, use_locking, name, read_value)
    if self._policy:
      return self._policy.assign_add(
          self,
          value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    return values_util.on_write_assign_add(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign(value, use_locking, name, read_value)
    if self._policy:
      return self._policy.assign(
          self,
          value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    return values_util.on_write_assign(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_sub(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_sub(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_sub(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_add(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_add(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_add(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_mul(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_mul(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_mul(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_div(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_div(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_div(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_min(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_min(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_max(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_max(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_update(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_update(
        self, sparse_delta, use_locking=use_locking, name=name)

  def __tf_tracing_type__(self, _):
    return DistributedVariableTraceType(self)

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    DistributedVariables.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """

    def _saveable_factory(name=self._common_name):
      return _DistributedVariableSaveable(self, self._primary, name)

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _as_graph_element(self):
    if values_util.is_saving_non_distributed():
      return self._primary._as_graph_element()  # pylint: disable=protected-access
    if self._policy:
      return self._policy._as_graph_element(self)  # pylint: disable=protected-access

    raise NotImplementedError(
        "DistributedVariable._as_graph_element requires a valid "
        "VariablePolicy. Please set the policy via the `var_policy` argument "
        "in the constructor, or override this method in sub-classes which "
        "support cross-replica accesses.")

  def _get_cross_replica(self):
    if values_util.is_saving_non_distributed():
      return self._primary
    if self._policy:
      return self._policy._get_cross_replica(self)  # pylint: disable=protected-access

    raise NotImplementedError(
        "DistributedVariable._get_cross_replica requires a valid "
        "VariablePolicy. Please set the policy via the `var_policy` argument "
        "in the constructor, or override this method in sub-classes which "
        "support cross-replica accesses.")

  def _update_cross_replica(self, update_fn, value, **kwargs):
    """Applies updates across replicas.

    Args:
      update_fn: A callable to pass to `strategy.extended.update` to update the
        variable. It should have the same signature as `Variable.assign()`.
      value: value to be passed to `update_fn`.
      **kwargs: remaining arguments to `update_fn`.

    Returns:
      Updated variable or `tf.Operation`.
    """
    values_util.mark_as_unsaveable()
    return self.distribute_strategy.extended.update(
        self, update_fn, args=(value,), kwargs=kwargs, group=True)

  def _update_replica(self, update_fn, value, **kwargs):
    """Applies updates in one replica.

    Args:
      update_fn: A callable to update the variable. It should have the same
        signature as `Variable.assign()`.
      value: value to be passed to `update_fn`.
      **kwargs: remaining arguments to `update_fn`.

    Returns:
      Updated variable or `tf.Operation`.
    """
    if self._policy:
      return self._policy._update_replica(self, update_fn, value, **kwargs)  # pylint: disable=protected-access
    raise NotImplementedError(
        "DistributedVariable._update_replica requires a valid VariablePolicy. "
        "Please set the policy via the `var_policy` argument in the "
        "constructor, or override this method in sub-classes which support "
        "cross-replica accesses.")

  def _update(self, update_fn, value, **kwargs):
    """Applies updates depending on the context.

    The method calls `_update_replica` in replica context,
    `_update_cross_replica` in cross replica context, and `update_fn` in update
    context.

    If `read_value` is True, the method returns the updated Variable. If
    `read_value` is False, the method returns the update `tf.Operation`.

    Args:
      update_fn: A callable to pass to `strategy.extended.update` to update the
        variable. It should have the same signature as `Variable.assign()`.
      value: value to be passed to `update_fn`.
      **kwargs: keyword arguments to `update_fn`.

    Returns:
      Updated variable or `tf.Operation`.

    """
    if values_util.is_saving_non_distributed():
      return update_fn(self._primary, value, **kwargs)
    with ds_context.enter_or_assert_strategy(self.distribute_strategy):
      if ds_context.in_cross_replica_context():
        update_replica_id = distribute_lib.get_update_replica_id()
        if update_replica_id is not None:
          replica_value = self._get_replica(update_replica_id)
          return update_fn(replica_value, value, **kwargs)
        return self._update_cross_replica(update_fn, value, **kwargs)
      else:
        values_util.assert_replica_context(self.distribute_strategy)
        return self._update_replica(update_fn, value, **kwargs)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    if values_util.is_saving_non_distributed():
      return ops.convert_to_tensor(
          self._primary, dtype=dtype, name=name, as_ref=as_ref)
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      return ops.convert_to_tensor(
          self._get(), dtype=dtype, name=name, as_ref=as_ref)

  def _map_resources(self, save_options):
    """For implementing `Trackable`."""
    # Initialize for self._primary first, so that obj_map[self._primary] and
    # resource_map[self._primary.handle] contain mapped values.
    obj_map, resource_map = self._primary._map_resources(save_options)  # pylint:disable=protected-access
    for v in [v for v in self._values if v != self._primary]:

      if (save_options.experimental_variable_policy  # pylint:disable=protected-access
          ._expand_distributed_variables()):
        v_obj_map, v_resource_map = v._map_resources(save_options)  # pylint:disable=protected-access
        obj_map.update(v_obj_map)
        resource_map.update(v_resource_map)
      else:
        obj_map[v] = obj_map[self._primary]
        resource_map[v.handle] = resource_map[self._primary.handle]
    obj_map[self] = obj_map[self._primary]
    resource_map[self] = resource_map[self._primary.handle]
    if self._packed_var is not None:
      resource_map[self._packed_var.packed_handle] = resource_map[
          self._primary.handle]
    return obj_map, resource_map

  def _write_object_proto(self, proto, options):
    """Update a SavedObject proto for the caller.

    If a DistributedVariable object supports this method, it will be called when
    saving with a pre-built `SavedObject` proto representing the object, plus an
    instance of `SaveOptions`. This method is then free to modify that proto
    instance.

    `DistributedVariable`s with `AUTO` or `ON_WRITE` synchronization optionally
    write out information about their components to the
    `experimental_distributed_variable_components` field of a
    `SavedVariable` (depending on the `SaveOptions` variable policy).

    Args:
      proto: A pre-built `SavedObject` proto for this object. It is assumed this
        will be a `SavedVariable` instance.
      options: A `SaveOptions` instance.
    """
    resource_variable_ops.write_object_proto_for_resource_variable(
        self, proto, options)
    if self._policy:
      if self._policy._is_mirrored():  # pylint: disable=protected-access
        self._policy._write_object_proto(self, proto, options)  # pylint: disable=protected-access

  @property
  def is_distributed_variable(self):
    return True

  def __tf_experimental_restore_capture__(
      self, concrete_function, internal_capture):
    concrete_function.graph.capture_distributed_variable(self, internal_capture)
    return self
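

# For illustration only (a sketch, not used by this module): users normally
# obtain a DistributedVariable by creating a `tf.Variable` under a strategy
# scope rather than constructing one directly, e.g.
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#   with strategy.scope():
#     v = tf.Variable(1.0)   # a MirroredVariable
#   v.values                 # one component variable per device
#   v.numpy()                # reads a single, synchronized value -> 1.0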


# We extend from `saveable_object.SaveableObject` instead of
# `saveable_object_util.ResourceVariableSaveable` since we need to read the
# value of ON_READ variables when saving. `SaveableObject` provides a way to
# specify the function to run to get the value of the variable or tensor at
# saving time. We can use this for both ON_READ and ON_WRITE variables.
# TODO(b/164586507): Consolidate ON_WRITE and ON_READ saving/restoring logic
# if possible.
class _DistributedVariableSaveable(saveable_object.SaveableObject):
  """Class for defining how to restore a DistributedVariable."""

  def __init__(self, distributed_variable, primary_variable, name):
    self._distributed_variable = distributed_variable
    if not self._distributed_variable._policy:
      raise ValueError(
          "The VariablePolicy of the argument `distributed_variable` must be "
          "set to create a _DistributedVariableSaveable. Please set it via "
          "the `var_policy` argument in the constructor of DistributedVariable."
      )
    tensor, spec = distributed_variable._policy.get_saveable(
        distributed_variable, primary_variable, name)
    super(_DistributedVariableSaveable, self).__init__(tensor, spec, name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    tensor, = restored_tensors
    return self._distributed_variable._policy.get_restore_ops(  # pylint: disable=protected-access
        self._distributed_variable, tensor)


class _MirroredSaveable(saveable_object.SaveableObject):
  """Class for defining how to restore a MirroredVariable."""

  def __init__(self, mirrored_variable, primary_variable, name):
    self._mirrored_variable = mirrored_variable
    tensor, spec = values_util.get_on_write_saveable(self._mirrored_variable,
                                                     primary_variable, name)
    super(_MirroredSaveable, self).__init__(tensor, spec, name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    tensor, = restored_tensors
    return values_util.get_on_write_restore_ops(self._mirrored_variable, tensor)
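

# For illustration only (a sketch): checkpointing a mirrored variable writes a
# single tensor and, on restore, assigns that tensor to every component, e.g.
#
#   ckpt = tf.train.Checkpoint(v=v)   # `v` created under a strategy scope
#   path = ckpt.save("/tmp/ckpt")     # one value is written for `v`
#   ckpt.restore(path)                # every replica's component gets it back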


class MirroredVariable(DistributedVariable, Mirrored):
  """Holds a map from replica to variables whose values are kept in sync."""

  def _update_replica(self, update_fn, value, **kwargs):
    return _on_write_update_replica(self, update_fn, value, **kwargs)

  def scatter_min(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(*args, **kwargs)
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError(
          values_util.scatter_error_msg.format(
              op_name="scatter_min", aggregation=self._aggregation))
    return super(MirroredVariable, self).scatter_min(*args, **kwargs)

  def scatter_max(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(*args, **kwargs)
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError(
          values_util.scatter_error_msg.format(
              op_name="scatter_max", aggregation=self._aggregation))
    return super(MirroredVariable, self).scatter_max(*args, **kwargs)

  def scatter_update(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(*args, **kwargs)
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError(
          values_util.scatter_error_msg.format(
              op_name="scatter_update", aggregation=self._aggregation))
    return super(MirroredVariable, self).scatter_update(*args, **kwargs)

  def _get_cross_replica(self):
    # Return identity, to avoid directly exposing the variable to the user and
    # allowing it to be modified by mistake.
    return array_ops.identity(Mirrored._get_cross_replica(self))

  def _as_graph_element(self):
    return self._get_on_device_or_primary()._as_graph_element()  # pylint: disable=protected-access

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    MirroredVariables.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """

    def _saveable_factory(name=self._common_name):
      return _MirroredSaveable(self, self._primary, name)

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _write_object_proto(self, proto, options):
    """Update a SavedObject proto for the caller.

    If a DistributedVariable object supports this method, it will be called when
    saving with a pre-built `SavedObject` proto representing the object, plus an
    instance of `SaveOptions`. This method is then free to modify that proto
    instance.

    `DistributedVariable`s with `AUTO` or `ON_WRITE` synchronization optionally
    write out information about their components to the
    `experimental_distributed_variable_components` field of a
    `SavedVariable` (depending on the `SaveOptions` variable policy).

    Args:
      proto: A pre-built `SavedObject` proto for this object. It is assumed this
        will be a `SavedVariable` instance.
      options: A `SaveOptions` instance.
    """
    super(MirroredVariable, self)._write_object_proto(proto, options)
    values_util.write_object_proto(self, proto, options)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    # TODO(b/154017756): Make _dense_var_to_tensor consistent between ON_READ
    # and ON_WRITE.
    # Try to avoid assignments to and other mutations of MirroredVariable
    # state except through a DistributionStrategy.extended.update() or any of
    # the `assign*` and `scatter*` calls.
    if as_ref:
      # A TF 1.x case where the variable is a boolean variable and used like:
      # tf.cond(v, true_fn, false_fn).
      raise ValueError(
          "You may be using a variable created under a distribute strategy in "
          "TF 1.x control flows. Try explicitly converting the variable to a "
          "Tensor using variable.read_value(), or switch to TF 2.x.")
    return ops.convert_to_tensor(
        self._get(), dtype=dtype, name=name, as_ref=as_ref)
1190
1191
1192class _SyncOnReadSaveable(saveable_object.SaveableObject):
1193  """Class for defining how to restore a SyncOnReadVariable."""
1194
1195  def __init__(self, sync_on_read_variable, name):
1196    self._sync_on_read_variable = sync_on_read_variable
1197    tensor, spec = values_util.get_on_read_saveable(
1198        sync_on_read_variable, sync_on_read_variable._primary, name)
1199
1200    super(_SyncOnReadSaveable, self).__init__(tensor, spec, name)
1201
1202  def restore(self, restored_tensors, restored_shapes):
1203    """Restore the same value into all variables."""
1204    tensor, = restored_tensors
1205    return values_util.get_on_read_restore_ops(
1206        self._sync_on_read_variable, tensor,
1207        self._sync_on_read_variable.aggregation)
1208
1209
1210class SyncOnReadVariable(DistributedVariable):
1211  """Holds a map from replica to variables whose values are reduced on save."""
1212
1213  def _update_replica(self, update_fn, value, **kwargs):
1214    return update_fn(self._get_on_device_or_primary(), value, **kwargs)
1215
1216  def _get(self):
1217    """Returns the value of SyncOnReadVariable based on surrounding context.
1218
1219    If called under a non-default replica-context, returns the corresponding
1220    variable on that replica.
1221    If called under default replica-context or cross-replica context, returns
1222    the synced value.
1223    """
1224    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
1225      return super(SyncOnReadVariable, self)._get()
1226
1227  # TODO(b/154017756): Make assign behaivor in cross replica context consistent
1228  # with MirroredVariable.
1229  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
1230    if values_util.is_saving_non_distributed():
1231      return self._primary.assign_sub(value, use_locking, name, read_value)
1232    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
1233      if (ds_context.in_cross_replica_context() and
1234          not values_util.in_replica_update_context()):
1235        values_util.mark_as_unsaveable()
1236        return values_util.on_read_assign_sub_cross_replica(
1237            self, value, read_value=read_value)
1238      else:
1239        return super(SyncOnReadVariable,
1240                     self).assign_sub(value, use_locking, name, read_value)
1241
1242  def assign_add(self, value, use_locking=False, name=None, read_value=True):
1243    if values_util.is_saving_non_distributed():
1244      return self._primary.assign_add(value, use_locking, name, read_value)
1245    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
1246      if (ds_context.in_cross_replica_context() and
1247          not values_util.in_replica_update_context()):
1248        values_util.mark_as_unsaveable()
1249        return values_util.on_read_assign_add_cross_replica(
1250            self, value, read_value=read_value)
1251      else:
1252        return super(SyncOnReadVariable,
1253                     self).assign_add(value, use_locking, name, read_value)
1254
1255  def assign(self, value, use_locking=False, name=None, read_value=True):
1256    if values_util.is_saving_non_distributed():
1257      return self._primary.assign(value, use_locking, name, read_value)
1258    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
1259      if (ds_context.in_cross_replica_context() and
1260          not values_util.in_replica_update_context()):
1261        values_util.mark_as_unsaveable()
1262        return values_util.on_read_assign_cross_replica(
1263            self, value, read_value=read_value)
1264      else:
1265        return super(SyncOnReadVariable, self).assign(value, use_locking, name,
1266                                                      read_value)
1267
1268  def _scatter_not_implemented(self, method):
1269    raise NotImplementedError(
1270        f"Variables with `synchronization=ON_READ` doesn't support `{method}`")
1271
1272  def scatter_sub(self, *args, **kwargs):
1273    if values_util.is_saving_non_distributed():
1274      return self._primary.scatter_sub(*args, **kwargs)
1275    self._scatter_not_implemented("scatter_sub")
1276
1277  def scatter_add(self, *args, **kwargs):
1278    if values_util.is_saving_non_distributed():
1279      return self._primary.scatter_add(*args, **kwargs)
1280    self._scatter_not_implemented("scatter_add")
1281
1282  def scatter_mul(self, *args, **kwargs):
1283    if values_util.is_saving_non_distributed():
1284      return self._primary.scatter_mul(*args, **kwargs)
1285    self._scatter_not_implemented("scatter_mul")
1286
1287  def scatter_div(self, *args, **kwargs):
1288    if values_util.is_saving_non_distributed():
1289      return self._primary.scatter_div(*args, **kwargs)
1290    self._scatter_not_implemented("scatter_div")
1291
1292  def scatter_min(self, *args, **kwargs):
1293    if values_util.is_saving_non_distributed():
1294      return self._primary.scatter_min(*args, **kwargs)
1295    self._scatter_not_implemented("scatter_min")
1296
1297  def scatter_max(self, *args, **kwargs):
1298    if values_util.is_saving_non_distributed():
1299      return self._primary.scatter_max(*args, **kwargs)
1300    self._scatter_not_implemented("scatter_max")
1301
1302  def scatter_update(self, *args, **kwargs):
1303    if values_util.is_saving_non_distributed():
1304      return self._primary.scatter_update(*args, **kwargs)
1305    self._scatter_not_implemented("scatter_update")
1306
1307  def value(self):
1308    if ds_context.in_variable_sync_on_read_context():
1309      raise NotImplementedError(
1310          "call `variable.value()` inside variable_sync_on_read_context is not "
1311          "supported")
1312    if values_util.is_saving_non_distributed():
1313      return self._primary.value()
1314    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
1315      if (ds_context.in_cross_replica_context() and
1316          not values_util.in_replica_update_context()):
1317        if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
1318          return self._get_replica(0).value()
1319        return self._get_cross_replica()
1320      else:
1321        # _get_on_device_or_primary() returns a Variable.
1322        return self._get_on_device_or_primary().value()
1323
1324  def read_value(self):
1325    if ds_context.in_variable_sync_on_read_context():
1326      raise NotImplementedError(
1327          "call `variable.read_value()` inside variable_sync_on_read_context is"
1328          " not supported")
1329    return super().read_value()
1330
1331  def _get_cross_replica(self):
1332    if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
1333      # Consider returning a tensor value here to make the return value of
1334      # _get_cross_replica consistent.
1335      return self._get_replica(0)
1336    if self._aggregation == vs.VariableAggregation.SUM:
1337      values_util.mark_as_unsaveable()
1338    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
1339      return self._distribute_strategy.reduce(
1340          reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
1341          self,
1342          axis=None)
1343
1344  def _as_graph_element(self):
1345    if values_util.is_saving_non_distributed():
1346      return self._primary._as_graph_element()  # pylint: disable=protected-access
1347    # pylint: disable=protected-access
1348    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
1349      if ds_context.in_cross_replica_context():
1350        return ops.convert_to_tensor(self._get_cross_replica())
1351    return self._get()._as_graph_element()
1352
1353  def _gather_saveables_for_checkpoint(self):
1354    """Overrides Trackable method.
1355
1356    This allows both name-based and object-based save and restore of
1357    `SyncOnReadVariable`s.
1358
1359    Returns:
1360      A dictionary mapping attribute names to `SaveableObject` factories.
1361    """
1362
1363    def _saveable_factory(name=self._common_name):
1364      return _SyncOnReadSaveable(self, name)
1365
1366    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
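
  # Hedged sketch of the checkpointing flow this hook enables (the strategy,
  # devices and path below are assumptions for illustration):
  #
  #   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  #   with strategy.scope():
  #     v = tf.Variable(0., synchronization=tf.VariableSynchronization.ON_READ,
  #                     aggregation=tf.VariableAggregation.MEAN)
  #   ckpt = tf.train.Checkpoint(v=v)
  #   path = ckpt.save("/tmp/sync_on_read")  # saves the aggregated value
  #   ckpt.restore(path)  # pushes the restored value back to every replica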
1367
1368  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
1369    """Converts a SyncOnReadVariable to a tensor."""
1370    if values_util.is_saving_non_distributed():
1371      return ops.convert_to_tensor(
1372          self._primary, dtype=dtype, name=name, as_ref=as_ref)
1373    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
1374      replica_context = ds_context.get_replica_context()
1375      if (replica_context is not None and
1376          ds_context.in_variable_sync_on_read_context()):
1377        if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
1378          return ops.convert_to_tensor(
1379              self._get_replica(0), dtype=dtype, name=name, as_ref=as_ref)
1380        if self._aggregation == vs.VariableAggregation.SUM:
1381          values_util.mark_as_unsaveable()
1382        # pylint: disable=protected-access
1383        reduced = (
1384            replica_context.strategy.extended._replica_ctx_all_reduce(
1385                reduce_util.ReduceOp.from_variable_aggregation(
1386                    self._aggregation),
1387                self._get().read_value()))
1388        return ops.convert_to_tensor(
1389            reduced, dtype=dtype, name=name, as_ref=as_ref)
1390
1391      return ops.convert_to_tensor(
1392          self._get(), dtype=dtype, name=name, as_ref=as_ref)
1393
1394
1395# Register conversion functions which read the value of the variable,
1396# allowing instances of these classes to be used as tensors.
1397# DistributedVariable
1398def _tensor_conversion_distributed_var(var,
1399                                       dtype=None,
1400                                       name=None,
1401                                       as_ref=False):
1402  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
1403
1404
1405ops.register_tensor_conversion_function(DistributedVariable,
1406                                        _tensor_conversion_distributed_var)
1407
1408
1409# MirroredVariables
1410def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False):
1411  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
1412
1413
1414ops.register_tensor_conversion_function(MirroredVariable,
1415                                        _tensor_conversion_mirrored)
1416
1417
1418# Mirrored Values
1419def _tensor_conversion_mirrored_val(value, dtype=None, name=None, as_ref=False):
1420  return ops.convert_to_tensor(
1421      value._get(), dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
1422
1423
1424ops.register_tensor_conversion_function(Mirrored,
1425                                        _tensor_conversion_mirrored_val)
1426
1427
1428# SyncOnReadVariables
1429def _tensor_conversion_sync_on_read(var, dtype=None, name=None, as_ref=False):
1430  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
1431
1432
1433ops.register_tensor_conversion_function(SyncOnReadVariable,
1434                                        _tensor_conversion_sync_on_read)
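
# Illustrative sketch (the strategy and devices are assumptions): with the
# conversion functions registered above, distributed variables and values can
# be passed directly to ops that expect plain tensors.
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#   with strategy.scope():
#     v = tf.Variable(3.0)     # MirroredVariable
#   tf.convert_to_tensor(v)    # explicit conversion of the local/primary copy
#   tf.add(v, 1.0)             # implicit conversion inside the op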
1435
1436
1437class VariablePolicy(object):
1438  """Policy defining synchronization and aggregation of a distributed variable.
1439
1440  Given `synchronization` and `aggregation` parameters set on a `tf.Variable`
1441  during variable creation within `tf.distribute` scope, `tf.distribute` creates
1442  an appropriate policy object and assigns it to the distributed variable. All
1443  variable operations are delegated to the respective policy object.
1444  """
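  # Rough mapping from creation arguments to policy, shown as a sketch (the
  # actual selection is performed by the enclosing strategy, not here):
  #
  #   with strategy.scope():
  #     # synchronization AUTO / ON_WRITE  -> OnWritePolicy
  #     v1 = tf.Variable(1.0)
  #     # synchronization ON_READ          -> OnReadPolicy
  #     v2 = tf.Variable(0.,
  #                      synchronization=tf.VariableSynchronization.ON_READ,
  #                      aggregation=tf.VariableAggregation.SUM)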
1445
1446  def __init__(self, aggregation):
1447    self._aggregation = aggregation
1448
1449  def value(self):
1450    raise NotImplementedError(
1451        "VariablePolicy.value should be overridden by subclasses.")
1452
1453  def _is_mirrored(self):
1454    raise NotImplementedError(
1455        "VariablePolicy._is_mirrored should be overridden by subclasses.")
1456
1457  def _as_graph_element(self, _):
1458    raise NotImplementedError(
1459        "VariablePolicy._as_graph_element should be overridden by subclasses.")
1460
1461  def _get_cross_replica(self, var):
1462    raise NotImplementedError(
1463        "VariablePolicy._get_cross_replica should be overridden by subclasses.")
1464
1465  def _update_replica(self, var, update_fn, value, **kwargs):
1466    raise NotImplementedError(
1467        "VariablePolicy._update_replica should be overridden by subclasses.")
1468
1469
1470class OnReadPolicy(VariablePolicy):
1471  """Policy defined for `tf.VariableSynchronization.ON_READ` synchronization.
1472
1473  This policy is created when `synchronization` is set to
1474  `tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the
1475  values allowed by the `tf.VariableAggregation` enum such as `NONE`, `SUM`,
1476  `MEAN` or `ONLY_FIRST_REPLICA`when creating a `tf.Variable` in `tf.distribute`
1477  `MEAN` or `ONLY_FIRST_REPLICA` when creating a `tf.Variable` in `tf.distribute`
1478  """
1479
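  # Illustrative sketch (the strategy and devices are assumptions): ON_READ
  # variables are typically updated locally inside `strategy.run` and
  # aggregated when read or assigned from cross-replica context, which is
  # what the methods below implement.
  #
  #   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  #   with strategy.scope():
  #     acc = tf.Variable(0., synchronization=tf.VariableSynchronization.ON_READ,
  #                       aggregation=tf.VariableAggregation.SUM)
  #   strategy.run(lambda: acc.assign_add(1.))  # local update on each replica
  #   acc.read_value()  # SUM-aggregated across replicas -> 2.0
  #   acc.assign(0.)    # cross-replica assign resets every replica's copy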
1480  def _is_mirrored(self):
1481    return False
1482
1483  def value(self, var):
1484    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
1485      if (ds_context.in_cross_replica_context() and
1486          not values_util.in_replica_update_context()):
1487        if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
1488          return var._get_replica(0).value()  # pylint: disable=protected-access
1489        return var._get_cross_replica()  # pylint: disable=protected-access
1490      else:
1491        return var._get_on_device_or_primary().value()  # pylint: disable=protected-access
1492
1493  def _as_graph_element(self, var):
1494    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
1495      if ds_context.in_cross_replica_context():
1496        return ops.convert_to_tensor(var._get_cross_replica())  # pylint: disable=protected-access
1497    return var._get()._as_graph_element()  # pylint: disable=protected-access
1498
1499  def _get_cross_replica(self, var):
1500    if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
1501      return var._get_replica(0)  # pylint: disable=protected-access
1502    if self._aggregation == vs.VariableAggregation.SUM:
1503      values_util.mark_as_unsaveable()
1504    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
1505      return var.distribute_strategy.reduce(
1506          reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
1507          var,
1508          axis=None)
1509
1510  def _update_replica(self, var, update_fn, value, **kwargs):
1511    return update_fn(var._get_on_device_or_primary(), value, **kwargs)  # pylint: disable=protected-access
1512
1513  def _scatter_not_implemented(self, method):
1514    raise NotImplementedError(f"ON_READ variables don't support `{method}` "
1515                              "in cross-replica context")
1516
1517  def assign_sub(self,
1518                 var,
1519                 value,
1520                 use_locking=False,
1521                 name=None,
1522                 read_value=True):
1523    """Subtracts a value from this variable."""
1524    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
1525      if (ds_context.in_cross_replica_context() and
1526          not values_util.in_replica_update_context()):
1527        values_util.mark_as_unsaveable()
1528        return values_util.on_read_assign_sub_cross_replica(
1529            var, value, read_value=read_value)
1530      else:
1531        return values_util.on_write_assign_sub(
1532            var,
1533            value,
1534            use_locking=use_locking,
1535            name=name,
1536            read_value=read_value)
1537
1538  def assign_add(self,
1539                 var,
1540                 value,
1541                 use_locking=False,
1542                 name=None,
1543                 read_value=True):
1544    """Adds a value to this variable."""
1545    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
1546      if (ds_context.in_cross_replica_context() and
1547          not values_util.in_replica_update_context()):
1548        values_util.mark_as_unsaveable()
1549        return values_util.on_read_assign_add_cross_replica(
1550            var, value, read_value=read_value)
1551      else:
1552        return values_util.on_write_assign_add(
1553            var,
1554            value,
1555            use_locking=use_locking,
1556            name=name,
1557            read_value=read_value)
1558
1559  def assign(self, var, value, use_locking=False, name=None, read_value=True):
1560    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
1561      if (ds_context.in_cross_replica_context() and
1562          not values_util.in_replica_update_context()):
1563        values_util.mark_as_unsaveable()
1564        return values_util.on_read_assign_cross_replica(
1565            var, value, read_value=read_value)
1566      else:
1567        return values_util.on_write_assign(
1568            var,
1569            value,
1570            use_locking=use_locking,
1571            name=name,
1572            read_value=read_value)
1573
1574  def scatter_sub(self, *args, **kwargs):
1575    del args, kwargs
1576    self._scatter_not_implemented("scatter_sub")
1577
1578  def scatter_add(self, *args, **kwargs):
1579    del args, kwargs
1580    self._scatter_not_implemented("scatter_add")
1581
1582  def scatter_mul(self, *args, **kwargs):
1583    del args, kwargs
1584    self._scatter_not_implemented("scatter_mul")
1585
1586  def scatter_div(self, *args, **kwargs):
1587    del args, kwargs
1588    self._scatter_not_implemented("scatter_div")
1589
1590  def scatter_min(self, *args, **kwargs):
1591    del args, kwargs
1592    self._scatter_not_implemented("scatter_min")
1593
1594  def scatter_max(self, *args, **kwargs):
1595    del args, kwargs
1596    self._scatter_not_implemented("scatter_max")
1597
1598  def scatter_update(self, *args, **kwargs):
1599    del args, kwargs
1600    self._scatter_not_implemented("scatter_update")
1601
1602  def get_saveable(self, var, primary_var, name):
1603    """Create a saveable object for the given variable."""
1604    return values_util.get_on_read_saveable(var, primary_var, name)
1605
1606  def get_restore_ops(self, var, tensor):
1607    """Restore the same value into all variables."""
1608    return values_util.get_on_read_restore_ops(var, tensor, self._aggregation)
1609
1610
1611class OnWritePolicy(VariablePolicy):
1612  """Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization.
1613
1614  This policy is created when `synchronization` is set to
1615  `tf.VariableSynchronization.ON_WRITE` or `tf.VariableSynchronization.AUTO`
1616  (together with any allowed `aggregation` value) when creating a
1617  `tf.Variable` in `tf.distribute` scope.
1618  """
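  # Illustrative sketch (the strategy and devices are assumptions): ON_WRITE
  # and AUTO variables aggregate updates at write time, so every copy stays
  # identical and reads are purely local.
  #
  #   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  #   with strategy.scope():
  #     w = tf.Variable(1.0, aggregation=tf.VariableAggregation.MEAN)
  #   def replica_fn():
  #     rid = tf.distribute.get_replica_context().replica_id_in_sync_group
  #     return w.assign_add(tf.cast(rid, tf.float32))
  #   strategy.run(replica_fn)  # MEAN of per-replica deltas applied everywhere
  #   w.value()                 # each replica now holds 1.5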
1619
1620  def _is_mirrored(self):
1621    return True
1622
1623  def value(self, var):
1624    return var._get_on_device_or_primary().value()  # pylint: disable=protected-access
1625
1626  def _as_graph_element(self, var):
1627    return var._get_on_device_or_primary()._as_graph_element()  # pylint: disable=protected-access
1628
1629  def _get_cross_replica(self, var):
1630    # Return identity, to avoid directly exposing the variable to the user and
1631    # allowing it to be modified by mistake.
1632    return array_ops.identity(var._get_on_device_or_primary())  # pylint: disable=protected-access
1633
1634  def _update_replica(self, var, update_fn, value, **kwargs):
1635    if var.aggregation == variables_lib.VariableAggregation.NONE:
1636      return update_fn(var._get_on_device_or_primary(), value, **kwargs)  # pylint: disable=protected-access
1637    return _on_write_update_replica(var, update_fn, value, **kwargs)
1638
1639  def assign(self, var, value, use_locking=False, name=None, read_value=True):
1640    return values_util.on_write_assign(
1641        var, value, use_locking=use_locking, name=name, read_value=read_value)
1642
1643  def assign_add(self,
1644                 var,
1645                 value,
1646                 use_locking=False,
1647                 name=None,
1648                 read_value=True):
1649    return values_util.on_write_assign_add(
1650        var, value, use_locking=use_locking, name=name, read_value=read_value)
1651
1652  def assign_sub(self,
1653                 var,
1654                 value,
1655                 use_locking=False,
1656                 name=None,
1657                 read_value=True):
1658    return values_util.on_write_assign_sub(
1659        var, value, use_locking=use_locking, name=name, read_value=read_value)
1660
1661  def scatter_sub(self, var, sparse_delta, use_locking=False, name=None):
1662    return values_util.scatter_sub(
1663        var, sparse_delta, use_locking=use_locking, name=name)
1664
1665  def scatter_add(self, var, sparse_delta, use_locking=False, name=None):
1666    return values_util.scatter_add(
1667        var, sparse_delta, use_locking=use_locking, name=name)
1668
1669  def scatter_mul(self, var, sparse_delta, use_locking=False, name=None):
1670    return values_util.scatter_mul(
1671        var, sparse_delta, use_locking=use_locking, name=name)
1672
1673  def scatter_div(self, var, sparse_delta, use_locking=False, name=None):
1674    return values_util.scatter_div(
1675        var, sparse_delta, use_locking=use_locking, name=name)
1676
1677  def scatter_min(self, var, sparse_delta, use_locking=False, name=None):
1678    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
1679        self._aggregation != vs.VariableAggregation.NONE):
1680      raise NotImplementedError(
1681          values_util.scatter_error_msg.format(
1682              op_name="scatter_min", aggregation=self._aggregation))
1683    return values_util.scatter_min(
1684        var, sparse_delta, use_locking=use_locking, name=name)
1685
1686  def scatter_max(self, var, sparse_delta, use_locking=False, name=None):
1687    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
1688        self._aggregation != vs.VariableAggregation.NONE):
1689      raise NotImplementedError(
1690          values_util.scatter_error_msg.format(
1691              op_name="scatter_max", aggregation=self._aggregation))
1692    return values_util.scatter_max(
1693        var, sparse_delta, use_locking=use_locking, name=name)
1694
1695  def scatter_update(self, var, sparse_delta, use_locking=False, name=None):
1696    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
1697        self._aggregation != vs.VariableAggregation.NONE):
1698      raise NotImplementedError(
1699          values_util.scatter_error_msg.format(
1700              op_name="scatter_update", aggregation=self._aggregation))
1701    return values_util.scatter_update(
1702        var, sparse_delta, use_locking=use_locking, name=name)
1703
1704  def get_saveable(self, var, primary_var, name):
1705    """Saveable ops for AUTO and ON_WRITE variables."""
1706    return values_util.get_on_write_saveable(var, primary_var, name)
1707
1708  def get_restore_ops(self, var, tensor):
1709    return values_util.get_on_write_restore_ops(var, tensor)
1710
1711  def _write_object_proto(self, var, proto, options):
1712    """Update a SavedObject proto for the caller.
1713
1714    If a DistributedVariable object supports this method, it will be called when
1715    saving with a pre-built `SavedObject` proto representing the object, plus an
1716    instance of `SaveOptions`. This method is then free to modify that proto
1717    instance.
1718
1719    `DistributedVariable`s with `AUTO` or `ON_WRITE` synchronization optionally
1720    write out information about their components to the
1721    `experimental_distributed_variable_components` field of a
1722    `SavedVariable` (depending on the `SaveOptions` variable policy).
1723
1724    Args:
1725      var: A `DistributedVariable` object.
1726      proto: A pre-built `SavedObject` proto for this object. It is assumed this
1727        will be a `SavedVariable` instance.
1728      options: A `SaveOptions` instance.
1729    """
1730    values_util.write_object_proto(var, proto, options)
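
  # Hedged sketch of how this hook is exercised when exporting a SavedModel
  # (`model`, the export path and the chosen policy are assumptions):
  #
  #   policy = tf.saved_model.experimental.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES
  #   options = tf.saved_model.SaveOptions(experimental_variable_policy=policy)
  #   tf.saved_model.save(model, "/tmp/exported", options=options)
  #   # With that policy, the SavedVariable proto also lists the component
  #   # variables in `experimental_distributed_variable_components`.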
1731
1732
1733class PerWorkerResource():
1734  """A per-worker CapturableResource class for non-ParameterServer strategy.
1735
1736  Resources that populate `host_to_resources` should be instances of classes
1737  subclassing CapturableResource, although currently this is only used and
1738  tested for StaticHashTable with TPUStrategy.
1739  """
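  # Minimal sketch of the expected inputs (the strategy, host device string
  # and table contents are assumptions for illustration):
  #
  #   table = tf.lookup.StaticHashTable(
  #       tf.lookup.KeyValueTensorInitializer([0, 1], ["a", "b"]), "UNK")
  #   host_to_resources = {
  #       "/job:worker/replica:0/task:0/device:CPU:0": table,
  #   }
  #   per_worker_table = PerWorkerResource(strategy, host_to_resources)
  #   per_worker_table.lookup(tf.constant([0]))  # uses the local worker's table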
1740
1741  def __init__(self, strategy, host_to_resources):
1742    distribute_lib.distribution_strategy_input_api_counter.get_cell(
1743        "PerWorkerResource", "TPUDistributedLookupTable").increase_by(1)
1744    self._strategy = strategy
1745    self._host_to_resources = host_to_resources
1746
1747  def __getattribute__(self, name):
1748    if name not in ("__init__", "__getattribute__", "_host_to_resources",
1749                    "_strategy", "local_resource"):
1750      return getattr(self.local_resource(), name)
1751    return super(PerWorkerResource, self).__getattribute__(name)
1752
1753  def __setattr__(self, name, value):
1754    if name not in ("_strategy", "_host_to_resources"):
1755      return setattr(self.local_resource(), name, value)
1756    return super(PerWorkerResource, self).__setattr__(name, value)
1757
1758  def local_resource(self):
1759    """Returns the resource on the local worker."""
1760    current_device = device_util.canonicalize(device_util.current())
1761    host_device = device_util.canonicalize(
1762        device_util.get_host_for_device(current_device))
1763    return self._host_to_resources.get(
1764        host_device,
1765        self._host_to_resources[next(iter(self._host_to_resources))])
1766