xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/resource_variable_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ops to use variables as resources."""
16
17# pylint: disable=g-bad-name
18import contextlib
19import functools
20import weakref
21
22import numpy as np
23
24from tensorflow.core.framework import attr_value_pb2
25from tensorflow.core.framework import variable_pb2
26from tensorflow.python.client import pywrap_tf_session
27from tensorflow.python.compat import compat as forward_compat
28from tensorflow.python.eager import context
29from tensorflow.python.eager import tape
30from tensorflow.python.framework import auto_control_deps_utils as acd
31from tensorflow.python.framework import constant_op
32from tensorflow.python.framework import cpp_shape_inference_pb2
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import errors
35from tensorflow.python.framework import indexed_slices
36from tensorflow.python.framework import meta_graph
37from tensorflow.python.framework import ops
38from tensorflow.python.framework import tensor_shape
39from tensorflow.python.framework import tensor_spec
40from tensorflow.python.ops import array_ops
41from tensorflow.python.ops import gen_array_ops
42from tensorflow.python.ops import gen_resource_variable_ops
43from tensorflow.python.ops import gen_state_ops
44from tensorflow.python.ops import handle_data_util
45from tensorflow.python.ops import math_ops
46from tensorflow.python.ops import state_ops
47from tensorflow.python.ops import variables
48# go/tf-wildcard-import
49# pylint: disable=wildcard-import
50from tensorflow.python.ops.gen_resource_variable_ops import *
51# pylint: enable=wildcard-import
52from tensorflow.python.trackable import base as trackable
53from tensorflow.python.types import core
54from tensorflow.python.util import _pywrap_utils
55from tensorflow.python.util import compat
56from tensorflow.python.util.deprecation import deprecated
57from tensorflow.python.util.tf_export import tf_export
58
59acd.register_read_only_resource_op("ReadVariableOp")
60acd.register_read_only_resource_op("VariableShape")
61acd.register_read_only_resource_op("ResourceGather")
62acd.register_read_only_resource_op("ResourceGatherNd")
63acd.register_read_only_resource_op("_ReadVariablesOp")
64
65# TODO(allenl): Remove this alias and migrate callers.
66get_resource_handle_data = handle_data_util.get_resource_handle_data
67
68
69def get_eager_safe_handle_data(handle):
70  """Get the data handle from the Tensor `handle`."""
71  assert isinstance(handle, ops.Tensor)
72
73  if isinstance(handle, ops.EagerTensor):
74    return handle._handle_data  # pylint: disable=protected-access
75  else:
76    return get_resource_handle_data(handle)
77
78
79def _set_handle_shapes_and_types(tensor, handle_data, graph_mode):
80  """Sets the shape inference result HandleData on tensor.
81
82  Args:
83    tensor: A `Tensor` or `EagerTensor`.
84    handle_data: A `CppShapeInferenceResult.HandleData`.
85    graph_mode: A python bool.
86  """
87  tensor._handle_data = handle_data  # pylint: disable=protected-access
88  if not graph_mode:
89    return
90
91  # Not an EagerTensor, so a graph tensor.
92  shapes, types = zip(
93      *[(pair.shape, pair.dtype) for pair in handle_data.shape_and_type])
94  ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
95  shapes = [
96      [d.size for d in s.dim]  # pylint: disable=g-complex-comprehension
97      if not s.unknown_rank else None for s in shapes
98  ]
99  with tensor._op._graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
100    pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
101        c_graph,
102        tensor._as_tf_output(),  # pylint: disable=protected-access
103        shapes,
104        ranks,
105        types)
106
107
108def _combine_handle_data(handle, initial_value):
109  """Concats HandleData from tensors `handle` and `initial_value`.
110
111  Args:
112    handle: A `Tensor` of dtype `resource`.
113    initial_value: A `Tensor`.
114
115  Returns:
116    A `CppShapeInferenceResult.HandleData`.  If `initial_value` has dtype
117    `variant`, the `HandleData` contains the concatenation of the shape_and_type
118    from both `handle` and `initial_value`.
119
120  Raises:
121    RuntimeError: If `handle`, which was returned by `VarHandleOp`, either has
122      no handle data or its `handle_data.shape_and_type` does not have length 1.
123  """
124  assert handle.dtype == dtypes.resource
125
126  variable_handle_data = get_eager_safe_handle_data(handle)
127
128  if initial_value.dtype != dtypes.variant:
129    return variable_handle_data
130
131  extra_handle_data = get_eager_safe_handle_data(initial_value)
132  if extra_handle_data is not None and extra_handle_data.is_set:
133    if (variable_handle_data is None or not variable_handle_data.is_set or
134        len(variable_handle_data.shape_and_type) != 1):
135      raise RuntimeError(
136          "Expected VarHandleOp to return a length==1 shape_and_type, "
137          f"but saw: '{variable_handle_data}'")
138    variable_handle_data.shape_and_type.extend(extra_handle_data.shape_and_type)
139  return variable_handle_data
140
141
142def _variable_handle_from_shape_and_dtype(shape,
143                                          dtype,
144                                          shared_name,
145                                          name,
146                                          graph_mode,
147                                          initial_value=None):
148  """Create a variable handle, copying in handle data from `initial_value`."""
149  container = ops.get_default_graph()._container  # pylint: disable=protected-access
150  if container is None:
151    container = ""
152  shape = tensor_shape.as_shape(shape)
153  dtype = dtypes.as_dtype(dtype)
154  if not graph_mode:
155    if shared_name is not None:
156      raise errors.InternalError(
157          node_def=None,
158          op=None,
159          message="Using an explicit shared_name is "
160          "not allowed when executing eagerly.")
161    shared_name = context.anonymous_name()
162
163  handle = gen_resource_variable_ops.var_handle_op(
164      shape=shape,
165      dtype=dtype,
166      shared_name=shared_name,
167      name=name,
168      container=container)
169  if initial_value is None:
170    initial_value = handle
171  if graph_mode:
172    full_handle_data = _combine_handle_data(handle, initial_value)
173    _set_handle_shapes_and_types(handle, full_handle_data, graph_mode)
174    return handle
175  else:
176    handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
177    handle_data.is_set = True
178    handle_data.shape_and_type.append(
179        cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
180            shape=shape.as_proto(), dtype=dtype.as_datatype_enum))
181
182    if initial_value is not None and initial_value.dtype == dtypes.variant:
183      extra_handle_data = get_eager_safe_handle_data(initial_value)
184      if extra_handle_data is not None and extra_handle_data.is_set:
185        if (not handle_data.is_set or len(handle_data.shape_and_type) != 1):
186          raise RuntimeError(
187              "Expected VarHandleOp to return a length==1 shape_and_type, "
188              f"but saw: '{handle_data}'")
189        handle_data.shape_and_type.extend(extra_handle_data.shape_and_type)
190
191    _set_handle_shapes_and_types(handle, handle_data, graph_mode)
192    return handle
193
194
195def eager_safe_variable_handle(initial_value, shape, shared_name, name,
196                               graph_mode):
197  """Creates a variable handle with information to do shape inference.
198
199  The dtype is read from `initial_value` and stored in the returned
200  resource tensor's handle data.
201
202  If `initial_value.dtype == tf.variant`, we additionally extract the handle
203  data (if any) from `initial_value` and append it to the `handle_data`.
204  In this case, the returned tensor's handle data is in the form
205
206  ```
207  is_set: true
208  shape_and_type {
209    shape {
210      // initial_value.shape
211    }
212    dtype: DT_VARIANT
213  }
214  shape_and_type {
215    // handle_data(initial_value).shape_and_type[0]
216  }
217  shape_and_type {
218    // handle_data(initial_value).shape_and_type[1]
219  }
220  ...
221  ```
222
223  Ops that read from this tensor, such as `ReadVariableOp` and
224  `AssignVariableOp`, know that `handle_data(handle).shape_and_type[1:]`
225  correspond to the handle data of the variant(s) stored in the Variable.
226
227  Args:
228    initial_value: A `Tensor`.
229    shape: The shape of the handle data. Can be `TensorShape(None)` (i.e.
230      unknown shape).
231    shared_name: A string.
232    name: A string.
233    graph_mode: A python bool.
234
235  Returns:
236    The handle, a `Tensor` of type `resource`.
237  """
238  dtype = initial_value.dtype.base_dtype
239  return _variable_handle_from_shape_and_dtype(shape, dtype, shared_name, name,
240                                               graph_mode, initial_value)
241
242
243@contextlib.contextmanager
244def _handle_graph(handle):
245  # Note: might have an eager tensor but not be executing eagerly when building
246  # functions.
247  if (context.executing_eagerly() or isinstance(handle, ops.EagerTensor) or
248      ops.has_default_graph()):
249    yield
250  else:
251    with handle.graph.as_default():
252      yield
253
254
255class EagerResourceDeleter:
256  """An object which cleans up a resource handle.
257
258  An alternative to defining a __del__ method on an object. The intended use is
259  that ResourceVariables or other objects with resource handles will maintain a
260  single reference to this object. When the parent object is collected, this
261  object will be too. Even if the parent object is part of a reference cycle,
262  the cycle will be collectable.
263  """
264
265  __slots__ = ["_handle", "_handle_device", "_context"]
266
267  def __init__(self, handle, handle_device):
268    if not isinstance(handle, ops.Tensor):
269      raise ValueError(
270          (f"Passed handle={handle} to EagerResourceDeleter. Was expecting "
271           f"the handle to be a `tf.Tensor`."))
272    self._handle = handle
273    self._handle_device = handle_device
274    # This is held since the __del__ function runs an op, and if the context()
275    # is collected before this object, there will be a segfault when running the
276    # op.
277    self._context = context.context()
278
279  def __del__(self):
280    # Resources follow object-identity when executing eagerly, so it is safe to
281    # delete the resource we have a handle to.
282    try:
283      # A packed EagerTensor doesn't own any resource.
284      if isinstance(self._handle, ops.EagerTensor) and self._handle.is_packed:
285        return
286      # This resource was created in eager mode. However, this destructor may be
287      # running in graph mode (especially during unit tests). To clean up
288      # successfully, we switch back into eager mode temporarily.
289      with context.eager_mode():
290        with ops.device(self._handle_device):
291          gen_resource_variable_ops.destroy_resource_op(
292              self._handle, ignore_lookup_error=True)
293    except TypeError:
294      # Suppress some exceptions, mainly for the case when we're running on
295      # module deletion. Things that can go wrong include the context module
296      # already being unloaded, self._handle._handle_data no longer being
297      # valid, and so on. Printing warnings in these cases is silly
298      # (exceptions raised from __del__ are printed as warnings to stderr).
299      pass  # 'NoneType' object is not callable when the handle has been
300      # partially unloaded.
301    except AttributeError:
302      pass  # 'NoneType' object has no attribute 'eager_mode' when context has
303      # been unloaded. Will catch other module unloads as well.
304
305
306def shape_safe_assign_variable_handle(handle, shape, value, name=None):
307  """Helper that checks shape compatibility and assigns variable."""
308  with _handle_graph(handle):
309    value_tensor = ops.convert_to_tensor(value)
310  shape.assert_is_compatible_with(value_tensor.shape)
311  return gen_resource_variable_ops.assign_variable_op(
312      handle, value_tensor, name=name)
313
314
315def _maybe_set_handle_data(dtype, handle, tensor):
316  if dtype == dtypes.variant:
317    # For DT_VARIANT types, the handle's shape_and_type[1:] stores the
318    # variant's handle data.  Extract it.
319    handle_data = get_eager_safe_handle_data(handle)
320    if handle_data.is_set and len(handle_data.shape_and_type) > 1:
321      tensor._handle_data = (  # pylint: disable=protected-access
322          cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
323              is_set=True, shape_and_type=handle_data.shape_and_type[1:]))
324
325
326def variable_accessed(variable):
327  """Records that `variable` was accessed for the tape and FuncGraph."""
328  if hasattr(ops.get_default_graph(), "watch_variable"):
329    ops.get_default_graph().watch_variable(variable)
330  if variable.trainable:
331    tape.variable_accessed(variable)
332
333
334class BaseResourceVariable(variables.VariableV1, core.Tensor):
335  """A python variable from an existing handle."""
336
337  # TODO(wangpeng): Deprecate `constraint` when callers no longer pass it in.
338  def __init__(  # pylint: disable=super-init-not-called
339      self,
340      trainable=None,
341      shape=None,
342      dtype=None,
343      handle=None,
344      constraint=None,
345      synchronization=None,
346      aggregation=None,
347      distribute_strategy=None,
348      name=None,
349      unique_id=None,
350      handle_name=None,
351      graph_element=None,
352      initial_value=None,
353      initializer_op=None,
354      is_initialized_op=None,
355      cached_value=None,
356      save_slice_info=None,
357      caching_device=None,
358      in_graph_mode=None,
359      validate_shape=True,
360      **unused_kwargs):
361    """Creates a variable from a handle.
362
363    Args:
364      trainable: If `True`, GradientTapes automatically watch uses of this
365        Variable.
366      shape: The variable's shape. This shape can be set to tf.TensorShape(None)
367        in order to assign values of different shapes to this variable.
368        Otherwise (i.e. if the shape is fully determined), it will trigger run
369        time checks to ensure that each assignment is of the same shape.
370      dtype: The variable's dtype.
371      handle: The variable's handle
372      constraint: An optional projection function to be applied to the variable
373        after being updated by an `Optimizer` (e.g. used to implement norm
374        constraints or value constraints for layer weights). The function must
375        take as input the unprojected Tensor representing the value of the
376        variable and return the Tensor for the projected value (which must have
377        the same shape). Constraints are not safe to use when doing asynchronous
378        distributed training.
379      synchronization: Indicates when a distributed variable will be
380        aggregated. Accepted values are constants defined in the class
381        `tf.VariableSynchronization`. By default the synchronization is set to
382        `AUTO` and the current `DistributionStrategy` chooses when to
383        synchronize.
384      aggregation: Indicates how a distributed variable will be aggregated.
385        Accepted values are constants defined in the class
386        `tf.VariableAggregation`.
387      distribute_strategy: The distribution strategy this variable was created
388        under.
389      name: The name for this variable.
390      unique_id: Internal. Unique ID for this variable's handle.
391      handle_name: The name for the variable's handle.
392      graph_element: Optional, required only in session.run-mode. Pre-created
393        tensor which reads this variable's value.
394      initial_value: Optional. Variable's initial value.
395      initializer_op: Operation which assigns the variable's initial value.
396      is_initialized_op: Pre-created operation to check whether this variable is
397        initialized.
398      cached_value: Pre-created operation to read this variable in a specific
399        device.
400      save_slice_info: Metadata for variable partitioning.
401      caching_device: Optional device string or function describing where the
402        Variable should be cached for reading.  Defaults to the Variable's
403        device.  If not `None`, caches on another device.  Typical use is to
404        cache on the device where the Ops using the Variable reside, to
405        deduplicate copying through `Switch` and other conditional statements.
406      in_graph_mode: whether we are executing in TF1 graph mode. If None, will
407        detect within the function. This is to avoid repeated init_scope()
408        context entrances, which can add up.
409      validate_shape: If `False`, allows the variable to be initialized with a
410        value of unknown shape. If `True`, the default, the shape of
411        `initial_value` must be known.
412    """
413    if in_graph_mode is None:
414      with ops.init_scope():
415        self._in_graph_mode = not context.executing_eagerly()
416    else:
417      self._in_graph_mode = in_graph_mode
418    synchronization, aggregation, trainable = (
419        variables.validate_synchronization_aggregation_trainable(
420            synchronization, aggregation, trainable, name))
421    self._trainable = trainable
422    self._synchronization = synchronization
423    self._aggregation = aggregation
424    self._save_slice_info = save_slice_info
425    self._initial_value = initial_value
426    self._initializer_op = initializer_op
427    self._is_initialized_op = is_initialized_op
428    self._graph_element = graph_element
429    self._caching_device = caching_device
430    self._cached_value = cached_value
431    self._distribute_strategy = distribute_strategy
432    # Store the graph key so optimizers know how to only retrieve variables from
433    # this graph. Guaranteed to be the same as the eager graph_key.
434    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
435    self._shape = tensor_shape.as_shape(shape)
436    self._dtype = dtypes.as_dtype(dtype)
437    self._handle = handle
438    self._unique_id = unique_id
439    if handle_name is None:
440      self._handle_name = "Variable:0"
441    else:
442      self._handle_name = handle_name + ":0"
443    self._constraint = constraint
444    self._cached_shape_as_list = None
445    self._validate_shape = validate_shape
446
447  def __repr__(self):
448    if context.executing_eagerly() and not self._in_graph_mode:
449      # If we cannot read the value for any reason (e.g. variable uninitialized
450      # during tf.function tracing), still produce a __repr__. Note that for
451      # async eager, errors due to uninitialized variables will raise in
452      # ops.value_text when the handle is resolved, so we need to keep that
453      # under the try...except if we want to suppress them.
454      try:
455        with ops.device(self.device):
456          value_text = ops.value_text(self.read_value(), is_repr=True)
457      except:  # pylint: disable=bare-except
458        value_text = "numpy=<unavailable>"
459
460      return "<tf.Variable '%s' shape=%s dtype=%s, %s>" % (
461          self.name, self.get_shape(), self.dtype.name, value_text)
462    else:
463      return "<tf.Variable '%s' shape=%s dtype=%s>" % (
464          self.name, self.get_shape(), self.dtype.name)
465
466  def __tf_tracing_type__(self, signature_context):
467    return signature_context.make_reference_type(
468        VariableSpec(self.shape, self.dtype), self._handle._id)  # pylint:disable=protected-access
469
470  @contextlib.contextmanager
471  def _assign_dependencies(self):
472    """Makes assignments depend on the cached value, if any.
473
474    This prevents undefined behavior with reads not ordered wrt writes.
475
476    Yields:
477      None.
478    """
479    if self._cached_value is not None:
480      with ops.control_dependencies([self._cached_value]):
481        yield
482    else:
483      yield
484
485  def __array__(self, dtype=None):
486    """Allows direct conversion to a numpy array.
487
488    >>> np.array(tf.Variable([1.0]))
489    array([1.], dtype=float32)
490
491    Returns:
492      The variable value as a numpy array.
493    """
494    # You can't return `self.numpy()` here because for scalars
495    # that raises:
496    #     ValueError: object __array__ method not producing an array
497    # Even `self.read_value().__array__()` and `self.read_value()._numpy()` give
498    # the same error. The `EagerTensor` class must be doing something behind the
499    # scenes to make `np.array(tf.constant(1))` work.
500    return np.asarray(self.numpy(), dtype=dtype)
501
502  def __nonzero__(self):
503    return self.__bool__()
504
505  def __bool__(self):
506    return bool(self.read_value())
507
508  def __copy__(self):
509    return self
510
511  def __deepcopy__(self, memo):
512    if not context.executing_eagerly():
513      raise NotImplementedError(
514          "__deepcopy__() is only available when eager execution is enabled.")
515    copied_variable = ResourceVariable(
516        initial_value=self.read_value(),
517        trainable=self._trainable,
518        constraint=self._constraint,
519        dtype=self._dtype,
520        name=self._shared_name,
521        distribute_strategy=self._distribute_strategy,
522        synchronization=self.synchronization,
523        aggregation=self.aggregation)
524    memo[self._unique_id] = copied_variable
525    return copied_variable
526
527  @property
528  def dtype(self):
529    """The dtype of this variable."""
530    return self._dtype
531
532  @property
533  def device(self):
534    """The device this variable is on."""
535    return self.handle.device
536
537  @property
538  def graph(self):
539    """The `Graph` of this variable."""
540    return self.handle.graph
541
542  @property
543  def name(self):
544    """The name of the handle for this variable."""
545    return self._handle_name
546
547  @property
548  def shape(self):
549    """The shape of this variable."""
550    return self._shape
551
552  def set_shape(self, shape):
553    self._shape = self._shape.merge_with(shape)
554
555  def _shape_as_list(self):
556    if self.shape.ndims is None:
557      return None
558    return [dim.value for dim in self.shape.dims]
559
560  def _shape_tuple(self):
561    shape = self._shape_as_list()
562    if shape is None:
563      return None
564    return tuple(shape)
565
566  @property
567  def create(self):
568    """The op responsible for initializing this variable."""
569    if not self._in_graph_mode:
570      raise RuntimeError("This operation is not supported "
571                         "when eager execution is enabled.")
572    return self._initializer_op
573
574  @property
575  def handle(self):
576    """The handle by which this variable can be accessed."""
577    return self._handle
578
579  def value(self):
580    """A cached operation which reads the value of this variable."""
581    if self._cached_value is not None:
582      return self._cached_value
583    with ops.colocate_with(None, ignore_existing=True):
584      return self._read_variable_op()
585
586  def _as_graph_element(self):
587    """Conversion function for Graph.as_graph_element()."""
588    return self._graph_element
589
590  @property
591  def initializer(self):
592    """The op responsible for initializing this variable."""
593    return self._initializer_op
594
595  @property
596  def initial_value(self):
597    """Returns the Tensor used as the initial value for the variable."""
598    if context.executing_eagerly():
599      raise RuntimeError("This property is not supported "
600                         "when eager execution is enabled.")
601    return self._initial_value
602
603  @property
604  def constraint(self):
605    """Returns the constraint function associated with this variable.
606
607    Returns:
608      The constraint function that was passed to the variable constructor.
609      Can be `None` if no constraint was passed.
610    """
611    return self._constraint
612
613  @property
614  def op(self):
615    """The op for this variable."""
616    return self.handle.op
617
618  @property
619  def trainable(self):
620    return self._trainable
621
622  @property
623  def synchronization(self):
624    return self._synchronization
625
626  @property
627  def aggregation(self):
628    return self._aggregation
629
630  def eval(self, session=None):
631    """Evaluates and returns the value of this variable."""
632    if context.executing_eagerly():
633      raise RuntimeError("This operation is not supported "
634                         "when eager execution is enabled.")
635    return self._graph_element.eval(session=session)
636
637  def numpy(self):
638    if context.executing_eagerly():
639      return self.read_value().numpy()
640    raise NotImplementedError(
641        "numpy() is only available when eager execution is enabled.")
642
643  @deprecated(None, "Prefer Dataset.range instead.")
644  def count_up_to(self, limit):
645    """Increments this variable until it reaches `limit`.
646
647    When that Op is run it tries to increment the variable by `1`. If
648    incrementing the variable would bring it above `limit` then the Op raises
649    the exception `OutOfRangeError`.
650
651    If no error is raised, the Op outputs the value of the variable before
652    the increment.
653
654    This is essentially a shortcut for `tf.compat.v1.count_up_to(self, limit)`.
655
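    As a hedged, eager-mode sketch of the behavior described above (the dtype
    and limit are chosen arbitrarily for illustration):

    ```python
    v = tf.Variable(0, dtype=tf.int64)
    v.count_up_to(2)  # returns 0; v is now 1
    v.count_up_to(2)  # returns 1; v is now 2
    v.count_up_to(2)  # raises tf.errors.OutOfRangeError
    ```
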
656    Args:
657      limit: value at which incrementing the variable raises an error.
658
659    Returns:
660      A `Tensor` that will hold the variable value before the increment. If no
661      other Op modifies this variable, the values produced will all be
662      distinct.
663    """
664    return gen_state_ops.resource_count_up_to(
665        self.handle, limit=limit, T=self.dtype)
666
667  def _map_resources(self, save_options):
668    """For implementing `Trackable`."""
669    new_variable = None
670    if save_options.experimental_variable_policy._save_variable_devices():  # pylint:disable=protected-access
671      with ops.device(self.device):
672        new_variable = copy_to_graph_uninitialized(self)
673    else:
674      new_variable = copy_to_graph_uninitialized(self)
675    obj_map = {self: new_variable}
676    resource_map = {self.handle: new_variable.handle}
677    return obj_map, resource_map
678
679  def _read_variable_op(self, no_copy=False):
680    """Reads the value of the variable.
681
682    If the variable is in copy-on-read mode and `no_copy` is True, the variable
683    is converted to copy-on-write mode before it is read.
684
685    Args:
686      no_copy: Whether to prevent a copy of the variable.
687
688    Returns:
689      The value of the variable.
690    """
691    variable_accessed(self)
692
693    def read_and_set_handle(no_copy):
694      if no_copy and forward_compat.forward_compatible(2022, 5, 3):
695        gen_resource_variable_ops.disable_copy_on_read(self.handle)
696      result = gen_resource_variable_ops.read_variable_op(
697          self.handle, self._dtype)
698      _maybe_set_handle_data(self._dtype, self.handle, result)
699      return result
700
701    if getattr(self, "_caching_device", None) is not None:
702      with ops.colocate_with(None, ignore_existing=True):
703        with ops.device(self._caching_device):
704          result = read_and_set_handle(no_copy)
705    else:
706      result = read_and_set_handle(no_copy)
707
708    if not context.executing_eagerly():
709      # Note that if a control flow context is active the input of the read op
710      # might not actually be the handle. This line bypasses it.
711      tape.record_operation(
712          "ReadVariableOp", [result], [self.handle],
713          backward_function=lambda x: [x],
714          forward_function=lambda x: [x])
715    return result
716
717  def read_value(self):
718    """Constructs an op which reads the value of this variable.
719
720    Should be used when there are multiple reads, or when it is desirable to
721    read the value only after some condition is true.
722
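    For illustration, a minimal eager-mode sketch (the value is arbitrary):

    ```python
    v = tf.Variable(3.0)
    v.read_value()  # a float32 Tensor holding 3.0
    ```
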
723    Returns:
724      The value of the variable.
725    """
726    with ops.name_scope("Read"):
727      value = self._read_variable_op()
728    # Return an identity so it can get placed on whatever device the context
729    # specifies instead of the device where the variable is.
730    return array_ops.identity(value)
731
732  def read_value_no_copy(self):
733    """Constructs an op which reads the value of this variable without copy.
734
735    The variable is read without making a copy even when it has been sparsely
736    accessed. Variables in copy-on-read mode will be converted to copy-on-write
737    mode.
738
739    Returns:
740      The value of the variable.
741    """
742    with ops.name_scope("Read"):
743      value = self._read_variable_op(no_copy=True)
744    # Return an identity so it can get placed on whatever device the context
745    # specifies instead of the device where the variable is.
746    return array_ops.identity(value)
747
748  def sparse_read(self, indices, name=None):
749    """Reads the value of this variable sparsely, using `gather`."""
750    with ops.name_scope("Gather" if name is None else name) as name:
751      variable_accessed(self)
752      value = gen_resource_variable_ops.resource_gather(
753          self.handle, indices, dtype=self._dtype, name=name)
754
755      if self._dtype == dtypes.variant:
756        # For DT_VARIANT types, the handle's shape_and_type[1:] stores the
757        # variant's handle data.  Extract it.
758        handle_data = get_eager_safe_handle_data(self.handle)
759        if handle_data.is_set and len(handle_data.shape_and_type) > 1:
760          value._handle_data = (  # pylint: disable=protected-access
761              cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
762                  is_set=True, shape_and_type=handle_data.shape_and_type[1:]))
763
764    return array_ops.identity(value)
765
766  def gather_nd(self, indices, name=None):
767    """Reads the value of this variable sparsely, using `gather_nd`."""
768    with ops.name_scope("GatherNd" if name is None else name) as name:
769      if self.trainable:
770        variable_accessed(self)
771      value = gen_resource_variable_ops.resource_gather_nd(
772          self.handle, indices, dtype=self._dtype, name=name)
773
774    return array_ops.identity(value)
775
776  def to_proto(self, export_scope=None):
777    """Converts a `ResourceVariable` to a `VariableDef` protocol buffer.
778
779    Args:
780      export_scope: Optional `string`. Name scope to remove.
781
782    Raises:
783      RuntimeError: If run in EAGER mode.
784
785    Returns:
786      A `VariableDef` protocol buffer, or `None` if the `Variable` is not
787      in the specified name scope.
788    """
789    if context.executing_eagerly():
790      raise RuntimeError("This operation is not supported "
791                         "when eager execution is enabled.")
792    if export_scope is None or self.handle.name.startswith(export_scope):
793      var_def = variable_pb2.VariableDef()
794      var_def.variable_name = ops.strip_name_scope(self.handle.name,
795                                                   export_scope)
796      if self._initial_value is not None:
797        # This is inside an if-statement for backwards compatibility, since
798        # self._initial_value might be None for variables constructed from old
799        # protos.
800        var_def.initial_value_name = ops.strip_name_scope(
801            self._initial_value.name, export_scope)
802      var_def.initializer_name = ops.strip_name_scope(self.initializer.name,
803                                                      export_scope)
804      if self._cached_value is not None:
805        var_def.snapshot_name = ops.strip_name_scope(self._cached_value.name,
806                                                     export_scope)
807      else:
808        # Store the graph_element here
809        var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name,
810                                                     export_scope)
811      var_def.is_resource = True
812      var_def.trainable = self.trainable
813      var_def.synchronization = self.synchronization.value
814      var_def.aggregation = self.aggregation.value
815      if self._save_slice_info:
816        var_def.save_slice_info_def.MergeFrom(
817            self._save_slice_info.to_proto(export_scope=export_scope))
818      return var_def
819    else:
820      return None
821
822  @staticmethod
823  def from_proto(variable_def, import_scope=None):
824    if context.executing_eagerly():
825      raise RuntimeError("This operation is not supported "
826                         "when eager execution is enabled.")
827    return ResourceVariable(
828        variable_def=variable_def, import_scope=import_scope)
829
830  __array_priority__ = 100
831
832  def is_initialized(self, name=None):
833    """Checks whether a resource variable has been initialized.
834
835    Outputs a boolean scalar indicating whether the tensor has been initialized.
836
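    A minimal eager-mode sketch (in eager execution the variable is created
    and initialized immediately, so the result is expected to be True):

    ```python
    v = tf.Variable(1.0)
    v.is_initialized()  # a scalar bool Tensor
    ```
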
837    Args:
838      name: A name for the operation (optional).
839
840    Returns:
841      A `Tensor` of type `bool`.
842    """
843    return gen_resource_variable_ops.var_is_initialized_op(self.handle, name)
844
845  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
846    """Subtracts a value from this variable.
847
848    Args:
849      delta: A `Tensor`. The value to subtract from this variable.
850      use_locking: If `True`, use locking during the operation.
851      name: The name to use for the operation.
852      read_value: A `bool`. Whether to read and return the new value of the
853        variable or not.
854
855    Returns:
856      If `read_value` is `True`, this method will return the new value of the
857      variable after the assignment has completed. Otherwise, when in graph mode
858      it will return the `Operation` that does the assignment, and when in eager
859      mode it will return `None`.
860    """
861    # TODO(apassos): this here and below is not atomic. Consider making it
862    # atomic if there's a way to do so without a performance cost for those who
863    # don't need it.
864    with _handle_graph(self.handle), self._assign_dependencies():
865      assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(
866          self.handle,
867          ops.convert_to_tensor(delta, dtype=self.dtype),
868          name=name)
869    if read_value:
870      return self._lazy_read(assign_sub_op)
871    return assign_sub_op
872
873  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
874    """Adds a value to this variable.
875
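    A hedged, eager-mode sketch (arbitrary values) showing the return-value
    behavior described below:

    ```python
    v = tf.Variable(1.0)
    v.assign_add(2.0)                    # v is now 3.0
    v.assign_add(1.0, read_value=False)  # returns None in eager mode
    ```
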
876    Args:
877      delta: A `Tensor`. The value to add to this variable.
878      use_locking: If `True`, use locking during the operation.
879      name: The name to use for the operation.
880      read_value: A `bool`. Whether to read and return the new value of the
881        variable or not.
882
883    Returns:
884      If `read_value` is `True`, this method will return the new value of the
885      variable after the assignment has completed. Otherwise, when in graph mode
886      it will return the `Operation` that does the assignment, and when in eager
887      mode it will return `None`.
888    """
889    with _handle_graph(self.handle), self._assign_dependencies():
890      assign_add_op = gen_resource_variable_ops.assign_add_variable_op(
891          self.handle,
892          ops.convert_to_tensor(delta, dtype=self.dtype),
893          name=name)
894    if read_value:
895      return self._lazy_read(assign_add_op)
896    return assign_add_op
897
898  def _lazy_read(self, op):
899    variable_accessed(self)
900    return _UnreadVariable(
901        handle=self.handle,
902        dtype=self.dtype,
903        shape=self._shape,
904        in_graph_mode=self._in_graph_mode,
905        parent_op=op,
906        unique_id=self._unique_id)
907
908  def assign(self, value, use_locking=None, name=None, read_value=True):
909    """Assigns a new value to this variable.
910
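    For illustration, a minimal eager-mode sketch (the values are arbitrary
    and assumed to be compatible with the variable's dtype and shape):

    ```python
    v = tf.Variable([1.0, 2.0])
    v.assign([3.0, 4.0])  # v now holds [3.0, 4.0]
    ```
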
911    Args:
912      value: A `Tensor`. The new value for this variable.
913      use_locking: If `True`, use locking during the assignment.
914      name: The name to use for the assignment.
915      read_value: A `bool`. Whether to read and return the new value of the
916        variable or not.
917
918    Returns:
919      If `read_value` is `True`, this method will return the new value of the
920      variable after the assignment has completed. Otherwise, when in graph mode
921      it will return the `Operation` that does the assignment, and when in eager
922      mode it will return `None`.
923    """
924    # Note: not depending on the cached value here since this can be used to
925    # initialize the variable.
926    with _handle_graph(self.handle):
927      value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
928      if not self._shape.is_compatible_with(value_tensor.shape):
929        if self.name is None:
930          tensor_name = ""
931        else:
932          tensor_name = " " + str(self.name)
933        raise ValueError(
934            (f"Cannot assign value to variable '{tensor_name}': Shape mismatch."
935             f"The variable shape {self._shape}, and the "
936             f"assigned value shape {value_tensor.shape} are incompatible."))
937      kwargs = {}
938      if forward_compat.forward_compatible(2022, 3, 23):
939        # If the shape is fully defined, we do a runtime check with the shape of
940        # value.
941        validate_shape = self._validate_shape and self._shape.is_fully_defined()
942        kwargs["validate_shape"] = validate_shape
943      assign_op = gen_resource_variable_ops.assign_variable_op(
944          self.handle, value_tensor, name=name, **kwargs)
945      if read_value:
946        return self._lazy_read(assign_op)
947    return assign_op
948
949  def __reduce__(self):
950    # The implementation mirrors that of __deepcopy__.
951    return functools.partial(
952        ResourceVariable,
953        initial_value=self.numpy(),
954        trainable=self.trainable,
955        name=self._shared_name,
956        dtype=self.dtype,
957        constraint=self.constraint,
958        distribute_strategy=self._distribute_strategy), ()
959
960  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
961    """Subtracts `tf.IndexedSlices` from this variable.
962
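    A hedged, eager-mode sketch (arbitrary values; rows not listed in
    `indices` are left untouched):

    ```python
    v = tf.Variable([1.0, 2.0, 3.0, 4.0])
    delta = tf.IndexedSlices(values=tf.constant([1.0, 1.0]),
                             indices=tf.constant([0, 2]))
    v.scatter_sub(delta)  # v is now [0.0, 2.0, 2.0, 4.0]
    ```
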
963    Args:
964      sparse_delta: `tf.IndexedSlices` to be subtracted from this variable.
965      use_locking: If `True`, use locking during the operation.
966      name: the name of the operation.
967
968    Returns:
969      The updated variable.
970
971    Raises:
972      TypeError: if `sparse_delta` is not an `IndexedSlices`.
973    """
974    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
975      raise TypeError(f"Argument `sparse_delta` must be a "
976                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
977    return self._lazy_read(
978        gen_resource_variable_ops.resource_scatter_sub(
979            self.handle,
980            sparse_delta.indices,
981            ops.convert_to_tensor(sparse_delta.values, self.dtype),
982            name=name))
983
984  def scatter_add(self, sparse_delta, use_locking=False, name=None):
985    """Adds `tf.IndexedSlices` to this variable.
986
987    Args:
988      sparse_delta: `tf.IndexedSlices` to be added to this variable.
989      use_locking: If `True`, use locking during the operation.
990      name: the name of the operation.
991
992    Returns:
993      The updated variable.
994
995    Raises:
996      TypeError: if `sparse_delta` is not an `IndexedSlices`.
997    """
998    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
999      raise TypeError(f"Argument `sparse_delta` must be a "
1000                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
1001    return self._lazy_read(
1002        gen_resource_variable_ops.resource_scatter_add(
1003            self.handle,
1004            sparse_delta.indices,
1005            ops.convert_to_tensor(sparse_delta.values, self.dtype),
1006            name=name))
1007
1008  def scatter_max(self, sparse_delta, use_locking=False, name=None):
1009    """Updates this variable with the max of `tf.IndexedSlices` and itself.
1010
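    A hedged, eager-mode sketch (arbitrary values; each referenced element
    becomes the maximum of its current value and the corresponding update):

    ```python
    v = tf.Variable([1.0, 2.0, 3.0, 4.0])
    delta = tf.IndexedSlices(values=tf.constant([0.0, 10.0]),
                             indices=tf.constant([0, 2]))
    v.scatter_max(delta)  # v is now [1.0, 2.0, 10.0, 4.0]
    ```
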
1011    Args:
1012      sparse_delta: `tf.IndexedSlices` to use as an argument of max with this
1013        variable.
1014      use_locking: If `True`, use locking during the operation.
1015      name: the name of the operation.
1016
1017    Returns:
1018      The updated variable.
1019
1020    Raises:
1021      TypeError: if `sparse_delta` is not an `IndexedSlices`.
1022    """
1023    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
1024      raise TypeError(f"Argument `sparse_delta` must be a "
1025                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
1026    return self._lazy_read(
1027        gen_resource_variable_ops.resource_scatter_max(
1028            self.handle,
1029            sparse_delta.indices,
1030            ops.convert_to_tensor(sparse_delta.values, self.dtype),
1031            name=name))
1032
1033  def scatter_min(self, sparse_delta, use_locking=False, name=None):
1034    """Updates this variable with the min of `tf.IndexedSlices` and itself.
1035
1036    Args:
1037      sparse_delta: `tf.IndexedSlices` to use as an argument of min with this
1038        variable.
1039      use_locking: If `True`, use locking during the operation.
1040      name: the name of the operation.
1041
1042    Returns:
1043      The updated variable.
1044
1045    Raises:
1046      TypeError: if `sparse_delta` is not an `IndexedSlices`.
1047    """
1048    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
1049      raise TypeError(f"Argument `sparse_delta` must be a "
1050                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
1051    return self._lazy_read(
1052        gen_resource_variable_ops.resource_scatter_min(
1053            self.handle,
1054            sparse_delta.indices,
1055            ops.convert_to_tensor(sparse_delta.values, self.dtype),
1056            name=name))
1057
1058  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
1059    """Multiply this variable by `tf.IndexedSlices`.
1060
1061    Args:
1062      sparse_delta: `tf.IndexedSlices` to multiply this variable by.
1063      use_locking: If `True`, use locking during the operation.
1064      name: the name of the operation.
1065
1066    Returns:
1067      The updated variable.
1068
1069    Raises:
1070      TypeError: if `sparse_delta` is not an `IndexedSlices`.
1071    """
1072    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
1073      raise TypeError(f"Argument `sparse_delta` must be a "
1074                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
1075    return self._lazy_read(
1076        gen_resource_variable_ops.resource_scatter_mul(
1077            self.handle,
1078            sparse_delta.indices,
1079            ops.convert_to_tensor(sparse_delta.values, self.dtype),
1080            name=name))
1081
1082  def scatter_div(self, sparse_delta, use_locking=False, name=None):
1083    """Divide this variable by `tf.IndexedSlices`.
1084
1085    Args:
1086      sparse_delta: `tf.IndexedSlices` to divide this variable by.
1087      use_locking: If `True`, use locking during the operation.
1088      name: the name of the operation.
1089
1090    Returns:
1091      The updated variable.
1092
1093    Raises:
1094      TypeError: if `sparse_delta` is not an `IndexedSlices`.
1095    """
1096    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
1097      raise TypeError(f"Argument `sparse_delta` must be a "
1098                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
1099    return self._lazy_read(
1100        gen_resource_variable_ops.resource_scatter_div(
1101            self.handle,
1102            sparse_delta.indices,
1103            ops.convert_to_tensor(sparse_delta.values, self.dtype),
1104            name=name))
1105
1106  def scatter_update(self, sparse_delta, use_locking=False, name=None):
1107    """Assigns `tf.IndexedSlices` to this variable.
1108
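    A hedged, eager-mode sketch (arbitrary values; only the rows named in
    `indices` are replaced):

    ```python
    v = tf.Variable([1.0, 2.0, 3.0, 4.0])
    delta = tf.IndexedSlices(values=tf.constant([10.0, 40.0]),
                             indices=tf.constant([0, 3]))
    v.scatter_update(delta)  # v is now [10.0, 2.0, 3.0, 40.0]
    ```
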
1109    Args:
1110      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
1111      use_locking: If `True`, use locking during the operation.
1112      name: the name of the operation.
1113
1114    Returns:
1115      The updated variable.
1116
1117    Raises:
1118      TypeError: if `sparse_delta` is not an `IndexedSlices`.
1119    """
1120    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
1121      raise TypeError(f"Argument `sparse_delta` must be a "
1122                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
1123    return self._lazy_read(
1124        gen_resource_variable_ops.resource_scatter_update(
1125            self.handle,
1126            sparse_delta.indices,
1127            ops.convert_to_tensor(sparse_delta.values, self.dtype),
1128            name=name))
1129
1130  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
1131    """Assigns `tf.IndexedSlices` to this variable batch-wise.
1132
1133    Analogous to `batch_gather`. This assumes that this variable and the
1134    sparse_delta IndexedSlices have a series of leading dimensions that are the
1135    same for all of them, and the updates are performed on the last dimension of
1136    indices. In other words, the dimensions should be the following:
1137
1138    `num_prefix_dims = sparse_delta.indices.ndims - 1`
1139    `batch_dim = num_prefix_dims + 1`
1140    `sparse_delta.values.shape = sparse_delta.indices.shape + var.shape[
1141         batch_dim:]`
1142
1143    where
1144
1145    `sparse_delta.values.shape[:num_prefix_dims]`
1146    `== sparse_delta.indices.shape[:num_prefix_dims]`
1147    `== var.shape[:num_prefix_dims]`
1148
1149    And the operation performed can be expressed as:
1150
1151    `var[i_1, ..., i_n,
1152         sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.values[
1153            i_1, ..., i_n, j]`
1154
1155    When sparse_delta.indices is a 1D tensor, this operation is equivalent to
1156    `scatter_update`.
1157
1158    To avoid this operation, one can loop over the first `ndims` dimensions of
1159    the variable and use `scatter_update` on the subtensors that result from
1160    slicing the first dimension. This is a valid option for `ndims = 1`, but less
1161    efficient than this implementation.
1162
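    As a hedged illustration of the shape relationships above (a hypothetical
    2x3 variable where each of the two leading "batch" rows carries its own
    index and update; values are arbitrary):

    ```python
    v = tf.Variable([[1.0, 2.0, 3.0],
                     [4.0, 5.0, 6.0]])
    delta = tf.IndexedSlices(
        values=tf.constant([[10.0], [60.0]]),   # shape [2, 1]
        indices=tf.constant([[0], [2]]))        # shape [2, 1]
    v.batch_scatter_update(delta)
    # v is expected to become [[10.0, 2.0, 3.0], [4.0, 5.0, 60.0]]
    ```
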
1163    Args:
1164      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
1165      use_locking: If `True`, use locking during the operation.
1166      name: the name of the operation.
1167
1168    Returns:
1169      The updated variable.
1170
1171    Raises:
1172      TypeError: if `sparse_delta` is not an `IndexedSlices`.
1173    """
1174    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
1175      raise TypeError(f"Argument `sparse_delta` must be a "
1176                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
1177    return self._lazy_read(
1178        state_ops.batch_scatter_update(
1179            self,
1180            sparse_delta.indices,
1181            sparse_delta.values,
1182            use_locking=use_locking,
1183            name=name))
1184
1185  def scatter_nd_sub(self, indices, updates, name=None):
1186    """Applies sparse subtraction to individual values or slices in a Variable.
1187
1188    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1189
1190    `indices` must be integer tensor, containing indices into `ref`.
1191    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1192
1193    The innermost dimension of `indices` (with length `K`) corresponds to
1194    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1195    dimension of `ref`.
1196
1197    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1198
1199    ```
1200    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1201    ```
1202
1203    For example, say we want to subtract 4 scattered elements from a rank-1
1204    tensor with 8 elements. In Python, that update would look like this:
1205
1206    ```python
1207        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
1208        indices = tf.constant([[4], [3], [1] ,[7]])
1209        updates = tf.constant([9, 10, 11, 12])
1210        op = ref.scatter_nd_sub(indices, updates)
1211        with tf.compat.v1.Session() as sess:
1212          print(sess.run(op))
1213    ```
1214
1215    The resulting update to ref would look like this:
1216
1217        [1, -9, 3, -6, -6, 6, 7, -4]
1218
1219    See `tf.scatter_nd` for more details about how to make updates to
1220    slices.
1221
1222    Args:
1223      indices: The indices to be used in the operation.
1224      updates: The values to be used in the operation.
1225      name: the name of the operation.
1226
1227    Returns:
1228      The updated variable.
1229    """
1230    return self._lazy_read(
1231        gen_state_ops.resource_scatter_nd_sub(
1232            self.handle,
1233            indices,
1234            ops.convert_to_tensor(updates, self.dtype),
1235            name=name))
1236
1237  def scatter_nd_add(self, indices, updates, name=None):
1238    """Applies sparse addition to individual values or slices in a Variable.
1239
1240    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1241
1242    `indices` must be integer tensor, containing indices into `ref`.
1243    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1244
1245    The innermost dimension of `indices` (with length `K`) corresponds to
1246    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1247    dimension of `ref`.
1248
1249    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1250
1251    ```
1252    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1253    ```
1254
1255    For example, say we want to add 4 scattered elements to a rank-1 tensor
1256    with 8 elements. In Python, that update would look like this:
1257
1258    ```python
1259        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
1260        indices = tf.constant([[4], [3], [1] ,[7]])
1261        updates = tf.constant([9, 10, 11, 12])
1262        add = ref.scatter_nd_add(indices, updates)
1263        with tf.compat.v1.Session() as sess:
1264          print(sess.run(add))
1265    ```
1266
1267    The resulting update to ref would look like this:
1268
1269        [1, 13, 3, 14, 14, 6, 7, 20]
1270
1271    See `tf.scatter_nd` for more details about how to make updates to
1272    slices.
1273
1274    Args:
1275      indices: The indices to be used in the operation.
1276      updates: The values to be used in the operation.
1277      name: the name of the operation.
1278
1279    Returns:
1280      The updated variable.
1281    """
1282    return self._lazy_read(
1283        gen_state_ops.resource_scatter_nd_add(
1284            self.handle,
1285            indices,
1286            ops.convert_to_tensor(updates, self.dtype),
1287            name=name))
1288
1289  def scatter_nd_update(self, indices, updates, name=None):
1290    """Applies sparse assignment to individual values or slices in a Variable.
1291
1292    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1293
1294    `indices` must be integer tensor, containing indices into `ref`.
1295    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1296
1297    The innermost dimension of `indices` (with length `K`) corresponds to
1298    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1299    dimension of `ref`.
1300
1301    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1302
1303    ```
1304    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1305    ```
1306
1307    For example, say we want to update 4 scattered elements in a rank-1 tensor
1308    with 8 elements. In Python, that update would look like this:
1309
1310    ```python
1311        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
1312        indices = tf.constant([[4], [3], [1] ,[7]])
1313        updates = tf.constant([9, 10, 11, 12])
1314        op = ref.scatter_nd_update(indices, updates)
1315        with tf.compat.v1.Session() as sess:
1316          print(sess.run(op))
1317    ```
1318
1319    The resulting update to ref would look like this:
1320
1321        [1, 11, 3, 10, 9, 6, 7, 12]
1322
1323    See `tf.scatter_nd` for more details about how to make updates to
1324    slices.
1325
1326    Args:
1327      indices: The indices to be used in the operation.
1328      updates: The values to be used in the operation.
1329      name: the name of the operation.
1330
1331    Returns:
1332      The updated variable.
1333    """
1334    return self._lazy_read(
1335        gen_state_ops.resource_scatter_nd_update(
1336            self.handle,
1337            indices,
1338            ops.convert_to_tensor(updates, self.dtype),
1339            name=name))
1340
1341  def scatter_nd_max(self, indices, updates, name=None):
1342    """Updates this variable with the max of `tf.IndexedSlices` and itself.
1343
1344    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1345
1346    `indices` must be integer tensor, containing indices into `ref`.
1347    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1348
1349    The innermost dimension of `indices` (with length `K`) corresponds to
1350    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1351    dimension of `ref`.
1352
1353    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1354
1355    ```
1356    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1357    ```
1358
1359    See `tf.scatter_nd` for more details about how to make updates to
1360    slices.
1361
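    A hedged sketch in the style of the examples above (eager mode; each
    referenced element becomes the maximum of its current value and the
    corresponding update):

    ```python
    v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1], [7]])
    updates = tf.constant([9, 10, 11, 12])
    v.scatter_nd_max(indices, updates)
    # v is now [1, 11, 3, 10, 9, 6, 7, 12]
    ```
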
1362    Args:
1363      indices: The indices to be used in the operation.
1364      updates: The values to be used in the operation.
1365      name: the name of the operation.
1366
1367    Returns:
1368      The updated variable.
1369    """
1370    return self._lazy_read(
1371        gen_state_ops.resource_scatter_nd_max(
1372            self.handle,
1373            indices,
1374            ops.convert_to_tensor(updates, self.dtype),
1375            name=name))
1376
1377  def scatter_nd_min(self, indices, updates, name=None):
1378    """Updates this variable with the min of `tf.IndexedSlices` and itself.
1379
1380    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1381
1382    `indices` must be integer tensor, containing indices into `ref`.
1383    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1384
1385    The innermost dimension of `indices` (with length `K`) corresponds to
1386    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1387    dimension of `ref`.
1388
1389    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1390
1391    ```
1392    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1393    ```
1394
1395    See `tf.scatter_nd` for more details about how to make updates to
1396    slices.
1397
1398    Args:
1399      indices: The indices to be used in the operation.
1400      updates: The values to be used in the operation.
1401      name: the name of the operation.
1402
1403    Returns:
1404      The updated variable.
1405    """
1406    return self._lazy_read(
1407        gen_state_ops.resource_scatter_nd_min(
1408            self.handle,
1409            indices,
1410            ops.convert_to_tensor(updates, self.dtype),
1411            name=name))
1412
1413  def _write_object_proto(self, proto, options):
1414    """Writes additional information of the variable into the SavedObject proto.
1415
    Subclasses of `ResourceVariable` may override this method to customize the
    extra information provided when saving a SavedModel.

    Ideally, this should contain the logic in
    write_object_proto_for_resource_variable but `DistributedValue` is an
    outlier at the moment. Once `DistributedValue` becomes a proper
    ResourceVariable, we should remove the helper method below.
1423
1424    Args:
1425      proto: `SavedObject` proto to update.
1426      options: A `SaveOption` instance that configures save behavior.
1427    """
1428    write_object_proto_for_resource_variable(self, proto, options)
1429
1430  def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
1431                            end_mask, ellipsis_mask, new_axis_mask,
1432                            shrink_axis_mask):
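    """Assigns `value` to a strided slice of this variable.

    Used by sliced assignment such as `v[begin:end].assign(value)`; returns
    the updated variable via a lazy read.
    """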
1433    with _handle_graph(self.handle), self._assign_dependencies():
1434      return self._lazy_read(
1435          gen_array_ops.resource_strided_slice_assign(
1436              ref=self.handle,
1437              begin=begin,
1438              end=end,
1439              strides=strides,
1440              value=ops.convert_to_tensor(value, dtype=self.dtype),
1441              name=name,
1442              begin_mask=begin_mask,
1443              end_mask=end_mask,
1444              ellipsis_mask=ellipsis_mask,
1445              new_axis_mask=new_axis_mask,
1446              shrink_axis_mask=shrink_axis_mask))
1447
1448  def __complex__(self):
1449    return complex(self.value().numpy())
1450
1451  def __int__(self):
1452    return int(self.value().numpy())
1453
1454  def __long__(self):
1455    return long(self.value().numpy())
1456
1457  def __float__(self):
1458    return float(self.value().numpy())
1459
1460  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
1461    del name
1462    if dtype is not None and not dtype.is_compatible_with(self.dtype):
1463      raise ValueError(
1464          f"Incompatible type conversion requested to type {dtype.name} for "
1465          f"`tf.Variable of type {self.dtype.name}. (Variable: {self})")
1466    if as_ref:
1467      return self.read_value().op.inputs[0]
1468    else:
1469      return self.value()
1470
1471  def __iadd__(self, unused_other):
1472    raise RuntimeError("`variable += value` with `tf.Variable`s is not "
1473                       "supported. Use `variable.assign_add(value)` to modify "
1474                       "the variable, or `out = variable + value` if you "
1475                       "need to get a new output Tensor.")
1476
1477  def __isub__(self, unused_other):
1478    raise RuntimeError("`variable -= value` with `tf.Variable`s is not "
1479                       "supported. Use `variable.assign_sub(value)` to modify "
1480                       "the variable, or `out = variable * value` if you "
1481                       "need to get a new output Tensor.")
1482
1483  def __imul__(self, unused_other):
1484    raise RuntimeError("`var *= value` with `tf.Variable`s is not "
1485                       "supported. Use `var.assign(var * value)` to modify "
1486                       "the variable, or `out = var * value` if you "
1487                       "need to get a new output Tensor.")
1488
1489  def __idiv__(self, unused_other):
1490    raise RuntimeError("`var /= value` with `tf.Variable`s is not "
1491                       "supported. Use `var.assign(var / value)` to modify "
1492                       "the variable, or `out = var / value` if you "
1493                       "need to get a new output Tensor.")
1494
1495  def __itruediv__(self, unused_other):
1496    raise RuntimeError("`var /= value` with `tf.Variable`s is not "
1497                       "supported. Use `var.assign(var / value)` to modify "
1498                       "the variable, or `out = var / value` if you "
1499                       "need to get a new output Tensor.")
1500
1501  def __irealdiv__(self, unused_other):
1502    raise RuntimeError("`var /= value` with `tf.Variable`s is not "
1503                       "supported. Use `var.assign(var / value)` to modify "
1504                       "the variable, or `out = var / value` if you "
1505                       "need to get a new output Tensor.")
1506
1507  def __ipow__(self, unused_other):
1508    raise RuntimeError("`var **= value` with `tf.Variable`s is not "
1509                       "supported. Use `var.assign(var ** value)` to modify "
1510                       "the variable, or `out = var ** value` if you "
1511                       "need to get a new output Tensor.")
1512
1513
1514class ResourceVariable(BaseResourceVariable):
1515  """Variable based on resource handles.
1516
1517  See the [Variables How To](https://tensorflow.org/guide/variables)
1518  for a high level overview.
1519
1520  A `ResourceVariable` allows you to maintain state across subsequent calls to
1521  session.run.
1522
1523  The `ResourceVariable` constructor requires an initial value for the variable,
1524  which can be a `Tensor` of any type and shape. The initial value defines the
1525  type and shape of the variable. After construction, the type and shape of
1526  the variable are fixed. The value can be changed using one of the assign
1527  methods.
1528
1529  Just like any `Tensor`, variables created with
1530  `tf.Variable(use_resource=True)` can be used as inputs for other Ops in the
1531  graph. Additionally, all the operators overloaded for the `Tensor` class are
1532  carried over to variables, so you can also add nodes to the graph by just
1533  doing arithmetic on variables.
1534
  Unlike ref-based variables, a ResourceVariable has well-defined semantics. Each
1536  usage of a ResourceVariable in a TensorFlow graph adds a read_value operation
1537  to the graph. The Tensors returned by a read_value operation are guaranteed to
1538  see all modifications to the value of the variable which happen in any
  operation on which the read_value depends (either directly, indirectly, or
  via a control dependency) and guaranteed not to see any modification to the
1541  value of the variable from operations that depend on the read_value operation.
1542  Updates from operations that have no dependency relationship to the read_value
1543  operation might or might not be visible to read_value.
1544
1545  For example, if there is more than one assignment to a ResourceVariable in
  a single session.run call, there is a well-defined value for each operation
1547  which uses the variable's value if the assignments and the read are connected
1548  by edges in the graph. Consider the following example, in which two writes
1549  can cause tf.Variable and tf.ResourceVariable to behave differently:
1550
1551  ```python
1552  a = tf.Variable(1.0, use_resource=True)
1553  a.initializer.run()
1554
1555  assign = a.assign(2.0)
1556  with tf.control_dependencies([assign]):
1557    b = a.read_value()
1558  with tf.control_dependencies([b]):
1559    other_assign = a.assign(3.0)
1560  with tf.control_dependencies([other_assign]):
1561    # Will print 2.0 because the value was read before other_assign ran. If
1562    # `a` was a tf.Variable instead, 2.0 or 3.0 could be printed.
1563    tf.compat.v1.Print(b, [b]).eval()
1564  ```
1565  """
1566
1567  def __init__(
1568      self,  # pylint: disable=super-init-not-called
1569      initial_value=None,
1570      trainable=None,
1571      collections=None,
1572      validate_shape=True,  # pylint: disable=unused-argument
1573      caching_device=None,
1574      name=None,
1575      dtype=None,
1576      variable_def=None,
1577      import_scope=None,
1578      constraint=None,
1579      distribute_strategy=None,
1580      synchronization=None,
1581      aggregation=None,
1582      shape=None):
1583    """Creates a variable.
1584
1585    Args:
1586      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
1587        which is the initial value for the Variable. Can also be a callable with
1588        no argument that returns the initial value when called. (Note that
1589        initializer functions from init_ops.py must first be bound to a shape
1590        before being used here.)
1591      trainable: If `True`, the default, also adds the variable to the graph
1592        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
1593        the default list of variables to use by the `Optimizer` classes.
1594        Defaults to `True`, unless `synchronization` is set to `ON_READ`, in
1595        which case it defaults to `False`.
1596      collections: List of graph collections keys. The new variable is added to
1597        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1598      validate_shape: If `False`, allows the variable to be initialized with a
1599        value of unknown shape. If `True`, the default, the shape of
1600        `initial_value` must be known.
1601      caching_device: Optional device string or function describing where the
1602        Variable should be cached for reading.  Defaults to the Variable's
1603        device.  If not `None`, caches on another device.  Typical use is to
1604        cache on the device where the Ops using the Variable reside, to
1605        deduplicate copying through `Switch` and other conditional statements.
1606      name: Optional name for the variable. Defaults to `'Variable'` and gets
1607        uniquified automatically.
1608      dtype: If set, initial_value will be converted to the given type. If None,
1609        either the datatype will be kept (if initial_value is a Tensor) or
1610        float32 will be used (if it is a Python object convertible to a Tensor).
1611      variable_def: `VariableDef` protocol buffer. If not None, recreates the
1612        `ResourceVariable` object with its contents. `variable_def` and other
1613        arguments (except for import_scope) are mutually exclusive.
1614      import_scope: Optional `string`. Name scope to add to the
1615        ResourceVariable. Only used when `variable_def` is provided.
1616      constraint: An optional projection function to be applied to the variable
1617        after being updated by an `Optimizer` (e.g. used to implement norm
1618        constraints or value constraints for layer weights). The function must
1619        take as input the unprojected Tensor representing the value of the
1620        variable and return the Tensor for the projected value (which must have
1621        the same shape). Constraints are not safe to use when doing asynchronous
1622        distributed training.
1623      distribute_strategy: The tf.distribute.Strategy this variable is being
1624        created inside of.
      synchronization: Indicates when a distributed variable will be
1626        aggregated. Accepted values are constants defined in the class
1627        `tf.VariableSynchronization`. By default the synchronization is set to
1628        `AUTO` and the current `DistributionStrategy` chooses when to
1629        synchronize.
1630      aggregation: Indicates how a distributed variable will be aggregated.
1631        Accepted values are constants defined in the class
1632        `tf.VariableAggregation`.
1633      shape: (optional) The shape of this variable. If None, the shape of
1634        `initial_value` will be used. When setting this argument to
1635        `tf.TensorShape(None)` (representing an unspecified shape), the variable
1636        can be assigned with values of different shapes.
1637
1638    Raises:
1639      ValueError: If the initial value is not specified, or does not have a
1640        shape and `validate_shape` is `True`.
1641
1642    @compatibility(eager)
1643    When Eager Execution is enabled, the default for the `collections` argument
1644    is `None`, which signifies that this `Variable` will not be added to any
1645    collections.
1646    @end_compatibility
1647    """
1648    if variable_def:
1649      if initial_value is not None:
1650        raise ValueError(f"The variable_def and initial_value args to "
1651                         f"`tf.Variable` are mutually exclusive, but got both: "
1652                         f"variable_def={variable_def},\n"
1653                         f"initial_value={initial_value}")
1654      if context.executing_eagerly():
1655        raise ValueError(f"Creating a `tf.Variable` with a `variable_def` arg "
1656                         f"is not supported when eager execution is enabled. "
1657                         f"Got: variable_def={variable_def}")
1658      self._init_from_proto(
1659          variable_def,
1660          import_scope=import_scope,
1661          validate_shape=validate_shape)
1662    else:
1663      self._init_from_args(
1664          initial_value=initial_value,
1665          trainable=trainable,
1666          collections=collections,
1667          caching_device=caching_device,
1668          name=name,
1669          dtype=dtype,
1670          constraint=constraint,
1671          synchronization=synchronization,
1672          aggregation=aggregation,
1673          shape=shape,
1674          distribute_strategy=distribute_strategy,
1675          validate_shape=validate_shape,
1676      )
1677
1678  def _init_from_args(
1679      self,
1680      initial_value=None,
1681      trainable=None,
1682      collections=None,
1683      caching_device=None,
1684      name=None,
1685      dtype=None,
1686      constraint=None,
1687      synchronization=None,
1688      aggregation=None,
1689      distribute_strategy=None,
1690      shape=None,
1691      validate_shape=True,
1692  ):
1693    """Creates a variable.
1694
1695    Args:
1696      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
1697        which is the initial value for the Variable. The initial value must have
1698        a shape specified unless `validate_shape` is set to False. Can also be a
1699        callable with no argument that returns the initial value when called.
1700        (Note that initializer functions from init_ops.py must first be bound to
1701        a shape before being used here.)
1702      trainable: If `True`, the default, also adds the variable to the graph
1703        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
1704        the default list of variables to use by the `Optimizer` classes.
1705        Defaults to `True`, unless `synchronization` is set to `ON_READ`, in
1706        which case it defaults to `False`.
1707      collections: List of graph collections keys. The new variable is added to
1708        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1709      caching_device: Optional device string or function describing where the
1710        Variable should be cached for reading.  Defaults to the Variable's
1711        device.  If not `None`, caches on another device.  Typical use is to
1712        cache on the device where the Ops using the Variable reside, to
1713        deduplicate copying through `Switch` and other conditional statements.
1714      name: Optional name for the variable. Defaults to `'Variable'` and gets
1715        uniquified automatically.
1716      dtype: If set, initial_value will be converted to the given type. If None,
1717        either the datatype will be kept (if initial_value is a Tensor) or
1718        float32 will be used (if it is a Python object convertible to a Tensor).
1719      constraint: An optional projection function to be applied to the variable
1720        after being updated by an `Optimizer` (e.g. used to implement norm
1721        constraints or value constraints for layer weights). The function must
1722        take as input the unprojected Tensor representing the value of the
1723        variable and return the Tensor for the projected value (which must have
1724        the same shape). Constraints are not safe to use when doing asynchronous
1725        distributed training.
      synchronization: Indicates when a distributed variable will be
1727        aggregated. Accepted values are constants defined in the class
1728        `tf.VariableSynchronization`. By default the synchronization is set to
1729        `AUTO` and the current `DistributionStrategy` chooses when to
1730        synchronize.
1731      aggregation: Indicates how a distributed variable will be aggregated.
1732        Accepted values are constants defined in the class
1733        `tf.VariableAggregation`.
1734      distribute_strategy: DistributionStrategy under which this variable was
1735        created.
1736      shape: (optional) The shape of this variable. If None, the shape of
1737        `initial_value` will be used. When setting this argument to
1738        `tf.TensorShape(None)` (representing an unspecified shape), the variable
1739        can be assigned with values of different shapes.
1740      validate_shape: If `False`, allows the variable to be initialized with a
1741        value of unknown shape. If `True`, the default, the shape of
1742        `initial_value` must be known.
1743
1744    Raises:
1745      ValueError: If the initial value is not specified, or does not have a
1746        shape and `validate_shape` is `True`.
1747
1748    @compatibility(eager)
1749    When Eager Execution is enabled, variables are never added to collections.
1750    It is not implicitly added to the `GLOBAL_VARIABLES` or
1751    `TRAINABLE_VARIABLES` collections, and the `collections` argument is
1752    ignored.
1753    @end_compatibility
1754    """
1755    synchronization, aggregation, trainable = (
1756        variables.validate_synchronization_aggregation_trainable(
1757            synchronization, aggregation, trainable, name))
1758    if initial_value is None:
1759      raise ValueError("The `initial_value` arg to `tf.Variable` must "
1760                       "be specified except when you are not providing a "
1761                       "`variable_def`. You provided neither.")
1762    init_from_fn = callable(initial_value)
1763
1764    if isinstance(initial_value, ops.Tensor) and hasattr(
1765        initial_value, "graph") and initial_value.graph.building_function:
1766      raise ValueError(f"Argument `initial_value` ({initial_value}) could not "
1767                       "be lifted out of a `tf.function`. "
1768                       f"(Tried to create variable with name='{name}'). "
1769                       "To avoid this error, when constructing `tf.Variable`s "
1770                       "inside of `tf.function` you can create the "
1771                       "`initial_value` tensor in a "
1772                       "`tf.init_scope` or pass a callable `initial_value` "
1773                       "(e.g., `tf.Variable(lambda : "
1774                       "tf.truncated_normal([10, 40]))`). "
1775                       "Please file a feature request if this "
1776                       "restriction inconveniences you.")
1777
1778    if collections is None:
1779      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
1780    if not isinstance(collections, (list, tuple, set)):
1781      raise ValueError(
1782          f"collections argument to Variable constructor must be a list, "
1783          f"tuple, or set. Got {collections} of type {type(collections)}")
1784    if constraint is not None and not callable(constraint):
1785      raise ValueError(f"Argument `constraint` must be None or a callable. "
1786                       f"a callable. Got a {type(constraint)}:  {constraint}")
1787
1788    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
1789      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
1790    with ops.init_scope():
1791      self._in_graph_mode = not context.executing_eagerly()
1792      with ops.name_scope(
1793          name,
1794          "Variable", [] if init_from_fn else [initial_value],
1795          skip_on_eager=False) as name:
1796        # pylint: disable=protected-access
1797        handle_name = ops.name_from_scope_name(name)
1798        if self._in_graph_mode:
1799          shared_name = handle_name
1800          unique_id = shared_name
1801        else:
1802          # When in eager mode use a uid for the shared_name, to prevent
1803          # accidental sharing.
1804          unique_id = "%s_%d" % (handle_name, ops.uid())
1805          shared_name = None  # Never shared
1806        # Use attr_scope and device(None) to simulate the behavior of
1807        # colocate_with when the variable we want to colocate with doesn't
1808        # yet exist.
1809        device_context_manager = (
1810            ops.device if self._in_graph_mode else ops.NullContextmanager)
1811        attr = attr_value_pb2.AttrValue(
1812            list=attr_value_pb2.AttrValue.ListValue(
1813                s=[compat.as_bytes("loc:@%s" % handle_name)]))
1814        with ops.get_default_graph()._attr_scope({"_class": attr}):
1815          with ops.name_scope("Initializer"), device_context_manager(None):
1816            if init_from_fn:
1817              initial_value = initial_value()
1818            if isinstance(initial_value, trackable.CheckpointInitialValue):
1819              self._maybe_initialize_trackable()
1820              self._update_uid = initial_value.checkpoint_position.restore_uid
1821              initial_value = initial_value.wrapped_value
1822            initial_value = ops.convert_to_tensor(
1823                initial_value, name="initial_value", dtype=dtype)
1824          if shape is not None:
1825            if not initial_value.shape.is_compatible_with(shape):
1826              raise ValueError(
1827                  f"In this `tf.Variable` creation, the initial value's shape "
1828                  f"({initial_value.shape}) is not compatible with "
1829                  f"the explicitly supplied `shape` argument ({shape}).")
1830          else:
1831            shape = initial_value.shape
1832          handle = eager_safe_variable_handle(
1833              initial_value=initial_value,
1834              shape=shape,
1835              shared_name=shared_name,
1836              name=name,
1837              graph_mode=self._in_graph_mode)
1838          handle._parent_trackable = weakref.ref(self)
1839        # pylint: disable=protected-access
1840        if (self._in_graph_mode and initial_value is not None and
1841            initial_value.op._get_control_flow_context() is not None):
1842          raise ValueError(
1843              f"The `initial_value` passed to `tf.Variable` {name} is from "
1844              f"inside a control-flow  construct, such as a loop or "
1845              f"conditional. When creating a "
1846              f"`tf.Variable` inside a loop or conditional, use a lambda as "
1847              f"the `initial_value`. Got: initial_value=({initial_value})")
1848        # pylint: enable=protected-access
1849        dtype = initial_value.dtype.base_dtype
1850
1851        if self._in_graph_mode:
1852          with ops.name_scope("IsInitialized"):
1853            is_initialized_op = (
1854                gen_resource_variable_ops.var_is_initialized_op(handle))
1855          if initial_value is not None:
1856            # pylint: disable=g-backslash-continuation
1857            with ops.name_scope("Assign") as n, \
1858                 ops.colocate_with(None, ignore_existing=True), \
1859                 ops.device(handle.device):
1860              # pylint: disable=protected-access
1861              initializer_op = (
1862                  gen_resource_variable_ops.assign_variable_op(
1863                      handle,
1864                      variables._try_guard_against_uninitialized_dependencies(
1865                          name, initial_value),
1866                      name=n))
1867              # pylint: enable=protected-access
1868            # pylint: enable=g-backslash-continuation
1869          with ops.name_scope("Read"):
1870            # Manually assign reads to the handle's device to avoid log
1871            # messages.
1872            with ops.device(handle.device):
1873              value = gen_resource_variable_ops.read_variable_op(handle, dtype)
1874              _maybe_set_handle_data(dtype, handle, value)
1875            graph_element = value
1876            if caching_device is not None:
1877              # Variables may be created in a tf.device() or ops.colocate_with()
1878              # context. At the same time, users would expect caching device to
1879              # be independent of this context, and/or would not expect the
1880              # current device context to be merged with the caching device
1881              # spec.  Therefore we reset the colocation stack before creating
1882              # the cached value. Note that resetting the colocation stack will
1883              # also reset the device stack.
1884              with ops.colocate_with(None, ignore_existing=True):
1885                with ops.device(caching_device):
1886                  cached_value = array_ops.identity(value)
1887            else:
1888              cached_value = None
1889        else:
1890          gen_resource_variable_ops.assign_variable_op(handle, initial_value)
1891          is_initialized_op = None
1892          initializer_op = None
1893          graph_element = None
1894          if caching_device:
1895            with ops.device(caching_device):
1896              cached_value = gen_resource_variable_ops.read_variable_op(
1897                  handle, dtype)
1898              _maybe_set_handle_data(dtype, handle, cached_value)
1899          else:
1900            cached_value = None
1901
1902        if cached_value is not None:
1903          # Store the variable object so that the original variable can be
1904          # accessed to generate functions that are compatible with SavedModel.
1905          cached_value._cached_variable = weakref.ref(self)  # pylint: disable=protected-access
1906
1907        if not context.executing_eagerly():
1908          # Eager variables are only added to collections if they are part of an
1909          # eager variable store (otherwise in an interactive session they would
1910          # hog memory and cause OOM). This is done in ops/variable_scope.py.
1911          ops.add_to_collections(collections, self)
1912        elif ops.GraphKeys.GLOBAL_STEP in collections:
1913          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)
1914      initial_value = initial_value if self._in_graph_mode else None
1915      super(ResourceVariable, self).__init__(
1916          trainable=trainable,
1917          shape=shape,
1918          dtype=dtype,
1919          handle=handle,
1920          synchronization=synchronization,
1921          constraint=constraint,
1922          aggregation=aggregation,
1923          distribute_strategy=distribute_strategy,
1924          name=name,
1925          unique_id=unique_id,
1926          handle_name=handle_name,
1927          graph_element=graph_element,
1928          initial_value=initial_value,
1929          initializer_op=initializer_op,
1930          is_initialized_op=is_initialized_op,
1931          cached_value=cached_value,
1932          caching_device=caching_device,
1933          validate_shape=validate_shape,
1934      )
1935
1936  def _init_from_proto(self,
1937                       variable_def,
1938                       import_scope=None,
1939                       validate_shape=True):
1940    """Initializes from `VariableDef` proto."""
1941    # Note that init_from_proto is currently not supported in Eager mode.
1942    assert not context.executing_eagerly()
1943    self._in_graph_mode = True
1944    assert isinstance(variable_def, variable_pb2.VariableDef)
1945    if not variable_def.is_resource:
1946      raise ValueError(f"The `variable_def` you passed to `tf.Variable` is "
1947                       f"Trying to restore a TF 1.x Reference Variable "
1948                       f"as a TF 2.x ResourceVariable. This is unsupported. "
1949                       f"Got variable_def={variable_def}")
1950
1951    # Create from variable_def.
1952    g = ops.get_default_graph()
1953    self._handle = g.as_graph_element(
1954        ops.prepend_name_scope(
1955            variable_def.variable_name, import_scope=import_scope))
1956    self._shape = tensor_shape.TensorShape(self._handle.op.get_attr("shape"))
1957    self._handle_name = self._handle.name
1958    self._unique_id = self._handle_name
1959    self._initializer_op = g.as_graph_element(
1960        ops.prepend_name_scope(
1961            variable_def.initializer_name, import_scope=import_scope))
1962    # Check whether initial_value_name exists for backwards compatibility.
1963    if (hasattr(variable_def, "initial_value_name") and
1964        variable_def.initial_value_name):
1965      self._initial_value = g.as_graph_element(
1966          ops.prepend_name_scope(
1967              variable_def.initial_value_name, import_scope=import_scope))
1968    else:
1969      self._initial_value = None
1970    synchronization, aggregation, trainable = (
1971        variables.validate_synchronization_aggregation_trainable(
1972            variable_def.synchronization, variable_def.aggregation,
1973            variable_def.trainable, variable_def.variable_name))
1974    self._synchronization = synchronization
1975    self._aggregation = aggregation
1976    self._trainable = trainable
1977    if variable_def.snapshot_name:
1978      snapshot = g.as_graph_element(
1979          ops.prepend_name_scope(
1980              variable_def.snapshot_name, import_scope=import_scope))
1981      if snapshot.op.type != "ReadVariableOp":
1982        self._cached_value = snapshot
1983      else:
1984        self._cached_value = None
1985      while snapshot.op.type != "ReadVariableOp":
1986        snapshot = snapshot.op.inputs[0]
1987      self._graph_element = snapshot
1988    else:
1989      self._cached_value = None
1990      # Legacy case for protos without the snapshot name; assume it's the
1991      # following.
1992      self._graph_element = g.get_tensor_by_name(self._handle.op.name +
1993                                                 "/Read/ReadVariableOp:0")
1994    if variable_def.HasField("save_slice_info_def"):
1995      self._save_slice_info = variables.Variable.SaveSliceInfo(
1996          save_slice_info_def=variable_def.save_slice_info_def,
1997          import_scope=import_scope)
1998    else:
1999      self._save_slice_info = None
2000    self._caching_device = None
2001    self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
2002    self._constraint = None
2003    self._validate_shape = validate_shape
2004
2005
2006class UninitializedVariable(BaseResourceVariable):
2007  """A variable with no initializer."""
2008
2009  def __init__(  # pylint: disable=super-init-not-called
2010      self,
2011      trainable=None,
2012      caching_device=None,
2013      name=None,
2014      shape=None,
2015      dtype=None,
2016      constraint=None,
2017      synchronization=None,
2018      aggregation=None,
2019      extra_handle_data=None,
2020      distribute_strategy=None,
2021      **unused_kwargs):
2022    """Creates the variable handle.
2023
2024    Args:
2025      trainable: If `True`, GradientTapes automatically watch uses of this
2026        Variable.
2027      caching_device: Optional device string or function describing where the
2028        Variable should be cached for reading.  Defaults to the Variable's
2029        device.  If not `None`, caches on another device.  Typical use is to
2030        cache on the device where the Ops using the Variable reside, to
2031        deduplicate copying through `Switch` and other conditional statements.
2032      name: Optional name for the variable. Defaults to `'Variable'` and gets
2033        uniquified automatically.
2034      shape: The variable's shape.
2035      dtype: The variable's dtype.
2036      constraint: An optional projection function to be applied to the variable
2037        after being updated by an `Optimizer` (e.g. used to implement norm
2038        constraints or value constraints for layer weights). The function must
2039        take as input the unprojected Tensor representing the value of the
2040        variable and return the Tensor for the projected value (which must have
2041        the same shape). Constraints are not safe to use when doing asynchronous
2042        distributed training.
      synchronization: Indicates when a distributed variable will be
2044        aggregated. Accepted values are constants defined in the class
2045        `tf.VariableSynchronization`. By default the synchronization is set to
2046        `AUTO` and the current `DistributionStrategy` chooses when to
2047        synchronize.
2048      aggregation: Indicates how a distributed variable will be aggregated.
2049        Accepted values are constants defined in the class
2050        `tf.VariableAggregation`.
2051      extra_handle_data: Optional, another resource handle or Tensor with handle
2052        data to merge with `shape` and `dtype`.
2053      distribute_strategy: The tf.distribute.Strategy this variable is being
2054        created inside of.
2055    """
2056    with ops.init_scope():
2057      # Here we are detecting eagerness within an init_scope, so this will only
2058      # be true when we are running in TF1 graph mode.
2059      self._in_graph_mode = not context.executing_eagerly()
2060      with ops.name_scope(name, "Variable", skip_on_eager=False) as name:
2061        handle_name = ops.name_from_scope_name(name)
2062        if self._in_graph_mode:
2063          shared_name = handle_name
2064          unique_id = shared_name
2065        else:
2066          unique_id = "%s_%d" % (handle_name, ops.uid())
2067          shared_name = None  # Never shared
2068        handle = _variable_handle_from_shape_and_dtype(
2069            shape=shape,
2070            dtype=dtype,
2071            shared_name=shared_name,
2072            name=name,
2073            graph_mode=self._in_graph_mode,
2074            initial_value=extra_handle_data)
2075        handle._parent_trackable = weakref.ref(self)
2076
2077        if self._in_graph_mode:
2078          # We only need to add the read_variable_op in TF1.
2079          with ops.name_scope("Read"):
2080            # Manually assign reads to the handle's device to avoid log
2081            # messages.
2082            with ops.device(handle.device):
2083              value = gen_resource_variable_ops.read_variable_op(handle, dtype)
2084              _maybe_set_handle_data(dtype, handle, value)
2085            graph_element = value
2086          ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, self)
2087          # Do *not* add to TRAINABLE_VARIABLES here, even if self._trainable,
2088          # because retraining or frozen use of imported SavedModels is
2089          # controlled at higher levels of model building.
2090        else:
2091          graph_element = None
2092    super(UninitializedVariable, self).__init__(
2093        distribute_strategy=distribute_strategy,
2094        shape=shape,
2095        dtype=dtype,
2096        unique_id=unique_id,
2097        handle_name=handle_name,
2098        constraint=constraint,
2099        handle=handle,
2100        graph_element=graph_element,
2101        trainable=trainable,
2102        synchronization=synchronization,
2103        aggregation=aggregation,
2104        in_graph_mode=self._in_graph_mode)
2105
2106
2107_pywrap_utils.RegisterType("ResourceVariable", ResourceVariable)
2108math_ops._resource_variable_type = ResourceVariable  # pylint: disable=protected-access
2109
2110
2111def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
2112  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
2113
2114
2115# Register a conversion function which reads the value of the variable,
2116# allowing instances of the class to be used as tensors.
2117ops.register_tensor_conversion_function(BaseResourceVariable,
2118                                        _dense_var_to_tensor)
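# For example, passing a `ResourceVariable` `w` directly to an op such as
# `tf.matmul(w, x)` goes through this conversion function and operates on the
# variable's current value.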
2119
2120
2121class _UnreadVariable(BaseResourceVariable):
2122  """Represents a future for a read of a variable.
2123
2124  Pretends to be the tensor if anyone looks.
2125  """
2126
2127  def __init__(self, handle, dtype, shape, in_graph_mode, parent_op, unique_id):
2128    if isinstance(handle, ops.EagerTensor):
2129      handle_name = ""
2130    else:
2131      handle_name = handle.name
2132    # Only create a graph_element if we're in session.run-land as only
2133    # session.run requires a preexisting tensor to evaluate. Otherwise we can
2134    # avoid accidentally reading the variable.
2135    if context.executing_eagerly() or ops.inside_function():
2136      graph_element = None
2137    else:
2138      with ops.control_dependencies([parent_op]):
2139        graph_element = gen_resource_variable_ops.read_variable_op(
2140            handle, dtype)
2141        _maybe_set_handle_data(dtype, handle, graph_element)
2142    super(_UnreadVariable, self).__init__(
2143        handle=handle,
2144        shape=shape,
2145        handle_name=handle_name,
2146        unique_id=unique_id,
2147        dtype=dtype,
2148        graph_element=graph_element)
2149    self._parent_op = parent_op
2150
2151  @property
2152  def name(self):
2153    if self._in_graph_mode:
2154      return self._parent_op.name
2155    else:
2156      return "UnreadVariable"
2157
2158  def value(self):
2159    return self._read_variable_op()
2160
2161  def read_value(self):
2162    return self._read_variable_op()
2163
2164  def _read_variable_op(self):
2165    with ops.control_dependencies([self._parent_op]):
2166      result = gen_resource_variable_ops.read_variable_op(
2167          self._handle, self._dtype)
2168      _maybe_set_handle_data(self._dtype, self._handle, result)
2169      return result
2170
2171  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
2172    with ops.control_dependencies([self._parent_op]):
2173      return super(_UnreadVariable, self).assign_sub(delta, use_locking, name,
2174                                                     read_value)
2175
2176  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
2177    with ops.control_dependencies([self._parent_op]):
2178      return super(_UnreadVariable, self).assign_add(delta, use_locking, name,
2179                                                     read_value)
2180
2181  def assign(self, value, use_locking=None, name=None, read_value=True):
2182    with ops.control_dependencies([self._parent_op]):
2183      return super(_UnreadVariable, self).assign(value, use_locking, name,
2184                                                 read_value)
2185
2186  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
2187    with ops.control_dependencies([self._parent_op]):
2188      return super(_UnreadVariable, self).scatter_sub(sparse_delta, use_locking,
2189                                                      name)
2190
2191  def scatter_add(self, sparse_delta, use_locking=False, name=None):
2192    with ops.control_dependencies([self._parent_op]):
2193      return super(_UnreadVariable, self).scatter_add(sparse_delta, use_locking,
2194                                                      name)
2195
2196  def scatter_max(self, sparse_delta, use_locking=False, name=None):
2197    with ops.control_dependencies([self._parent_op]):
2198      return super(_UnreadVariable, self).scatter_max(sparse_delta, use_locking,
2199                                                      name)
2200
2201  def scatter_min(self, sparse_delta, use_locking=False, name=None):
2202    with ops.control_dependencies([self._parent_op]):
2203      return super(_UnreadVariable, self).scatter_min(sparse_delta, use_locking,
2204                                                      name)
2205
2206  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
2207    with ops.control_dependencies([self._parent_op]):
2208      return super(_UnreadVariable, self).scatter_mul(sparse_delta, use_locking,
2209                                                      name)
2210
2211  def scatter_div(self, sparse_delta, use_locking=False, name=None):
2212    with ops.control_dependencies([self._parent_op]):
2213      return super(_UnreadVariable, self).scatter_div(sparse_delta, use_locking,
2214                                                      name)
2215
2216  def scatter_update(self, sparse_delta, use_locking=False, name=None):
2217    with ops.control_dependencies([self._parent_op]):
2218      return super(_UnreadVariable,
2219                   self).scatter_update(sparse_delta, use_locking, name)
2220
2221  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
2222    with ops.control_dependencies([self._parent_op]):
2223      return super(_UnreadVariable,
2224                   self).batch_scatter_update(sparse_delta, use_locking, name)
2225
2226  def scatter_nd_sub(self, indices, updates, name=None):
2227    with ops.control_dependencies([self._parent_op]):
2228      return super(_UnreadVariable, self).scatter_nd_sub(indices, updates, name)
2229
2230  def scatter_nd_add(self, indices, updates, name=None):
2231    with ops.control_dependencies([self._parent_op]):
2232      return super(_UnreadVariable, self).scatter_nd_add(indices, updates, name)
2233
2234  def scatter_nd_update(self, indices, updates, name=None):
2235    with ops.control_dependencies([self._parent_op]):
2236      return super(_UnreadVariable,
2237                   self).scatter_nd_update(indices, updates, name)
2238
2239  def scatter_nd_max(self, indices, updates, name=None):
2240    with ops.control_dependencies([self._parent_op]):
2241      return super(_UnreadVariable, self).scatter_nd_max(indices, updates, name)
2242
2243  def scatter_nd_min(self, indices, updates, name=None):
2244    with ops.control_dependencies([self._parent_op]):
2245      return super(_UnreadVariable, self).scatter_nd_min(indices, updates, name)
2246
2247  @property
2248  def op(self):
2249    """The op for this variable."""
2250    return self._parent_op
2251
2252
2253@ops.RegisterGradient("ReadVariableOp")
2254def _ReadGrad(_, grad):
2255  """Gradient for read op."""
2256  return grad
2257
2258
2259def variable_shape(handle, out_type=dtypes.int32):
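  """Returns the shape of the resource variable behind `handle` as a tensor.

  Uses the statically known shape from the handle data when it is fully
  defined, returning it as a constant; otherwise falls back to the
  `VariableShape` op.
  """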
2260  handle_data = get_eager_safe_handle_data(handle)
2261  if handle_data is None or not handle_data.is_set:
2262    return gen_resource_variable_ops.variable_shape(handle, out_type=out_type)
2263  shape_proto = handle_data.shape_and_type[0].shape
2264  if shape_proto.unknown_rank or any(x.size == -1 for x in shape_proto.dim):
2265    return gen_resource_variable_ops.variable_shape(handle, out_type=out_type)
2266  return constant_op.constant([x.size for x in shape_proto.dim], dtype=out_type)
2267
2268
2269@ops.RegisterGradient("ResourceGather")
2270def _GatherGrad(op, grad):
2271  """Gradient for gather op."""
2272  # Build appropriately shaped IndexedSlices
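  # The gradient w.r.t. the gathered parameters is sparse: reshape `grad` so
  # that each row of `values` corresponds to one gathered index, then wrap the
  # result in an IndexedSlices carrying the dense parameter shape.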
2273  handle = op.inputs[0]
2274  indices = op.inputs[1]
2275  params_shape = variable_shape(handle)
2276  size = array_ops.expand_dims(array_ops.size(indices), 0)
2277  values_shape = array_ops.concat([size, params_shape[1:]], 0)
2278  values = array_ops.reshape(grad, values_shape)
2279  indices = array_ops.reshape(indices, size)
2280  return (indexed_slices.IndexedSlices(values, indices, params_shape), None)
2281
2282
2283def _to_proto_fn(v, export_scope=None):
2284  """Converts Variable and ResourceVariable to VariableDef for collections."""
2285  return v.to_proto(export_scope=export_scope)
2286
2287
2288def _from_proto_fn(v, import_scope=None):
2289  """Creates Variable or ResourceVariable from VariableDef as needed."""
2290  if v.is_resource:
2291    return ResourceVariable.from_proto(v, import_scope=import_scope)
2292  return variables.Variable.from_proto(v, import_scope=import_scope)
2293
2294
2295ops.register_proto_function(
2296    ops.GraphKeys.GLOBAL_VARIABLES,
2297    proto_type=variable_pb2.VariableDef,
2298    to_proto=_to_proto_fn,
2299    from_proto=_from_proto_fn)
2300ops.register_proto_function(
2301    ops.GraphKeys.TRAINABLE_VARIABLES,
2302    proto_type=variable_pb2.VariableDef,
2303    to_proto=_to_proto_fn,
2304    from_proto=_from_proto_fn)
2305ops.register_proto_function(
2306    ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
2307    proto_type=variable_pb2.VariableDef,
2308    to_proto=_to_proto_fn,
2309    from_proto=_from_proto_fn)
2310ops.register_proto_function(
2311    ops.GraphKeys.LOCAL_VARIABLES,
2312    proto_type=variable_pb2.VariableDef,
2313    to_proto=_to_proto_fn,
2314    from_proto=_from_proto_fn)
2315ops.register_proto_function(
2316    ops.GraphKeys.MODEL_VARIABLES,
2317    proto_type=variable_pb2.VariableDef,
2318    to_proto=_to_proto_fn,
2319    from_proto=_from_proto_fn)
2320ops.register_proto_function(
2321    ops.GraphKeys.GLOBAL_STEP,
2322    proto_type=variable_pb2.VariableDef,
2323    to_proto=_to_proto_fn,
2324    from_proto=_from_proto_fn)
2325ops.register_proto_function(
2326    ops.GraphKeys.METRIC_VARIABLES,
2327    proto_type=variable_pb2.VariableDef,
2328    to_proto=_to_proto_fn,
2329    from_proto=_from_proto_fn)
2330
2331
2332@tf_export("__internal__.ops.is_resource_variable", v1=[])
2333def is_resource_variable(var):
2334  """"Returns True if `var` is to be considered a ResourceVariable."""
2335  return isinstance(var, BaseResourceVariable) or hasattr(
2336      var, "_should_act_as_resource_variable")
2337
2338
2339def copy_to_graph_uninitialized(var):
2340  """Copies an existing variable to a new graph, with no initializer."""
2341  # Like ResourceVariable.__deepcopy__, but does not set an initializer on the
2342  # new variable.
2343  # pylint: disable=protected-access
2344  new_variable = UninitializedVariable(
2345      trainable=var.trainable,
2346      constraint=var._constraint,
2347      shape=var.shape,
2348      dtype=var.dtype,
2349      name=var._shared_name,
2350      synchronization=var.synchronization,
2351      aggregation=var.aggregation,
2352      extra_handle_data=var.handle)
2353  new_variable._maybe_initialize_trackable()
2354  # pylint: enable=protected-access
2355  return new_variable
2356
2357
2358ops.NotDifferentiable("Assert")
2359ops.NotDifferentiable("VarIsInitializedOp")
2360ops.NotDifferentiable("VariableShape")
2361
2362
2363class VariableSpec(tensor_spec.DenseSpec):
2364  """Describes a tf.Variable."""
2365
2366  __slots__ = ["trainable"]
2367
2368  value_type = property(lambda self: BaseResourceVariable)
2369
2370  def __init__(self, shape, dtype=dtypes.float32, trainable=True):
2371    super(VariableSpec, self).__init__(shape, dtype=dtype)
2372    self.trainable = trainable
2373
2374  def is_compatible_with(self, spec_or_value):
2375    return (isinstance(spec_or_value, (type(self), self.value_type)) and
2376            self.shape.is_compatible_with(spec_or_value.shape) and
2377            self.dtype == spec_or_value.dtype and
2378            self.trainable == spec_or_value.trainable)
2379
2380  @classmethod
2381  def from_value(cls, value):
2382    return cls(value.shape, dtype=value.dtype, trainable=value.trainable)
2383
2384  def _to_components(self, value):
2385    return value.handle
2386
2387  def _from_components(self, components):
2388    return BaseResourceVariable(
2389        trainable=self.trainable,
2390        shape=self.shape,
2391        dtype=self.dtype,
2392        handle=components)
2393
2394  @property
2395  def _component_specs(self):
2396    return tensor_spec.TensorSpec(self.shape, dtypes.resource)
2397
2398  def _from_compatible_tensor_list(self, tensor_list):
2399    assert len(tensor_list) == 1
2400    return tensor_list[0]
2401
2402  def _serialize(self):
2403    return self.shape, self.dtype, self.trainable
2404
2405  def __tf_tracing_type__(self, signature_context):
2406    return signature_context.make_reference_type(self, id(self))
2407
2408  def __repr__(self):
2409    return (f"{type(self).__name__}(shape={self.shape}, dtype={self.dtype}, "
2410            f"trainable={self.trainable})")
2411
2412  def __hash__(self):
2413    return hash((self.shape, self.dtype, self.trainable))
2414
2415  def __eq__(self, other):
2416    return (type(self) is type(other) and self.shape == other.shape and
2417            self.dtype == other.dtype and self.trainable == other.trainable)
2418
2419
2420_pywrap_utils.RegisterType("VariableSpec", VariableSpec)
2421
2422
2423def write_object_proto_for_resource_variable(resource_variable,
2424                                             proto,
2425                                             options,
2426                                             enforce_naming=True):
2427  """Writes additional information of the variable into the SavedObject proto.
2428
  This allows users to define a `hook` to provide extra information about the
  variable to the SavedObject.
2431
2432  For example, DistributedVariable class would fill in components in the
2433  distributed context.
2434
2435  Args:
2436    resource_variable: A `ResourceVariable` or `DistributedValue` that has the
2437      information to be saved into the proto.
2438    proto: `SavedObject` proto to update.
2439    options: A `SaveOption` instance that configures save behavior.
2440    enforce_naming: A bool determining whether to check that names end in the
      expected string ':0'.
2442  """
2443  proto.variable.SetInParent()
2444  if enforce_naming and not resource_variable.name.endswith(":0"):
2445    raise ValueError(f"Cowardly refusing to save variable "
2446                     f"{resource_variable.name} because of "
2447                     f"unexpected suffix in the name (expected ':0')"
2448                     f"which won't be restored.")
2449  proto.variable.name = meta_graph._op_name(resource_variable.name)  # pylint: disable=protected-access
2450  proto.variable.trainable = resource_variable.trainable
2451  proto.variable.dtype = resource_variable.dtype.as_datatype_enum
2452  proto.variable.synchronization = resource_variable.synchronization.value
2453  proto.variable.aggregation = resource_variable.aggregation.value
2454  proto.variable.shape.CopyFrom(resource_variable.shape.as_proto())
2455  if options.experimental_variable_policy._save_variable_devices(  # pylint: disable=protected-access
2456  ):
2457    if hasattr(resource_variable, "device"):
2458      proto.variable.device = resource_variable.device
2459