# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops to use variables as resources."""

# pylint: disable=g-bad-name
import contextlib
import functools
import weakref

import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.compat import compat as forward_compat
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import auto_control_deps_utils as acd
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import handle_data_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_resource_variable_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.trackable import base as trackable
from tensorflow.python.types import core
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export

acd.register_read_only_resource_op("ReadVariableOp")
acd.register_read_only_resource_op("VariableShape")
acd.register_read_only_resource_op("ResourceGather")
acd.register_read_only_resource_op("ResourceGatherNd")
acd.register_read_only_resource_op("_ReadVariablesOp")

# TODO(allenl): Remove this alias and migrate callers.
get_resource_handle_data = handle_data_util.get_resource_handle_data


def get_eager_safe_handle_data(handle):
  """Gets the handle data from the Tensor `handle`."""
  assert isinstance(handle, ops.Tensor)

  if isinstance(handle, ops.EagerTensor):
    return handle._handle_data  # pylint: disable=protected-access
  else:
    return get_resource_handle_data(handle)


def _set_handle_shapes_and_types(tensor, handle_data, graph_mode):
  """Sets the shape inference result HandleData on tensor.

  Args:
    tensor: A `Tensor` or `EagerTensor`.
    handle_data: A `CppShapeInferenceResult.HandleData`.
    graph_mode: A python bool.
  """
  tensor._handle_data = handle_data  # pylint: disable=protected-access
  if not graph_mode:
    return

  # Not an EagerTensor, so a graph tensor.
  shapes, types = zip(
      *[(pair.shape, pair.dtype) for pair in handle_data.shape_and_type])
  ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
  shapes = [
      [d.size for d in s.dim]  # pylint: disable=g-complex-comprehension
      if not s.unknown_rank else None for s in shapes
  ]
  with tensor._op._graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
    pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
        c_graph,
        tensor._as_tf_output(),  # pylint: disable=protected-access
        shapes,
        ranks,
        types)


def _combine_handle_data(handle, initial_value):
  """Concatenates HandleData from tensors `handle` and `initial_value`.

  Args:
    handle: A `Tensor` of dtype `resource`.
    initial_value: A `Tensor`.

  Returns:
    A `CppShapeInferenceResult.HandleData`. If `initial_value` has dtype
    `variant`, the `HandleData` contains the concatenation of the
    shape_and_type from both `handle` and `initial_value`.

  Raises:
    RuntimeError: If `handle`, which was returned by VarHandleOp, either has
      no handle data, or its len(handle_data.shape_and_type) != 1.
  """
  assert handle.dtype == dtypes.resource

  variable_handle_data = get_eager_safe_handle_data(handle)

  if initial_value.dtype != dtypes.variant:
    return variable_handle_data

  extra_handle_data = get_eager_safe_handle_data(initial_value)
  if extra_handle_data is not None and extra_handle_data.is_set:
    if (variable_handle_data is None or not variable_handle_data.is_set or
        len(variable_handle_data.shape_and_type) != 1):
      raise RuntimeError(
          "Expected VarHandleOp to return a length==1 shape_and_type, "
          f"but saw: '{variable_handle_data}'")
    variable_handle_data.shape_and_type.extend(extra_handle_data.shape_and_type)
  return variable_handle_data


def _variable_handle_from_shape_and_dtype(shape,
                                          dtype,
                                          shared_name,
                                          name,
                                          graph_mode,
                                          initial_value=None):
  """Create a variable handle, copying in handle data from `initial_value`."""
  container = ops.get_default_graph()._container  # pylint: disable=protected-access
  if container is None:
    container = ""
  shape = tensor_shape.as_shape(shape)
  dtype = dtypes.as_dtype(dtype)
  if not graph_mode:
    if shared_name is not None:
      raise errors.InternalError(
          node_def=None,
          op=None,
          message="Using an explicit shared_name is "
          "not allowed when executing eagerly.")
    shared_name = context.anonymous_name()

  handle = gen_resource_variable_ops.var_handle_op(
      shape=shape,
      dtype=dtype,
      shared_name=shared_name,
      name=name,
      container=container)
  if initial_value is None:
    initial_value = handle
  if graph_mode:
    full_handle_data = _combine_handle_data(handle, initial_value)
    _set_handle_shapes_and_types(handle, full_handle_data, graph_mode)
    return handle
  else:
    handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
    handle_data.is_set = True
    handle_data.shape_and_type.append(
        cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
            shape=shape.as_proto(), dtype=dtype.as_datatype_enum))

    if initial_value is not None and initial_value.dtype == dtypes.variant:
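      # A variant-dtype initial value may itself carry handle data (e.g. the
      # shapes and dtypes of nested tensors); it is forwarded below so that
      # readers of the variable can recover that information.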
      extra_handle_data = get_eager_safe_handle_data(initial_value)
      if extra_handle_data is not None and extra_handle_data.is_set:
        if (not handle_data.is_set or len(handle_data.shape_and_type) != 1):
          raise RuntimeError(
              "Expected VarHandleOp to return a length==1 shape_and_type, "
              f"but saw: '{handle_data}'")
        handle_data.shape_and_type.extend(extra_handle_data.shape_and_type)

    _set_handle_shapes_and_types(handle, handle_data, graph_mode)
    return handle


def eager_safe_variable_handle(initial_value, shape, shared_name, name,
                               graph_mode):
  """Creates a variable handle with information to do shape inference.

  The dtype is read from `initial_value` and stored in the returned
  resource tensor's handle data.

  If `initial_value.dtype == tf.variant`, we additionally extract the handle
  data (if any) from `initial_value` and append it to the `handle_data`.
  In this case, the returned tensor's handle data is in the form

  ```
  is_set: true
  shape_and_type {
    shape {
      // initial_value.shape
    }
    dtype: DT_VARIANT
  }
  shape_and_type {
    // handle_data(initial_value).shape_and_type[0]
  }
  shape_and_type {
    // handle_data(initial_value).shape_and_type[1]
  }
  ...
  ```

  Ops that read from this tensor, such as `ReadVariableOp` and
  `AssignVariableOp`, know that `handle_data(handle).shape_and_type[1:]`
  correspond to the handle data of the variant(s) stored in the Variable.

  Args:
    initial_value: A `Tensor`.
    shape: The shape of the handle data. Can be `TensorShape(None)` (i.e.
      unknown shape).
    shared_name: A string.
    name: A string.
    graph_mode: A python bool.

  Returns:
    The handle, a `Tensor` of type `resource`.
  """
  dtype = initial_value.dtype.base_dtype
  return _variable_handle_from_shape_and_dtype(shape, dtype, shared_name, name,
                                               graph_mode, initial_value)


@contextlib.contextmanager
def _handle_graph(handle):
  # Note: might have an eager tensor but not be executing eagerly when building
  # functions.
  if (context.executing_eagerly() or isinstance(handle, ops.EagerTensor) or
      ops.has_default_graph()):
    yield
  else:
    with handle.graph.as_default():
      yield


class EagerResourceDeleter:
  """An object which cleans up a resource handle.

  An alternative to defining a __del__ method on an object. The intended use is
  that ResourceVariables or other objects with resource handles will maintain a
  single reference to this object. When the parent object is collected, this
  object will be too. Even if the parent object is part of a reference cycle,
  the cycle will be collectable.
  """

  __slots__ = ["_handle", "_handle_device", "_context"]

  def __init__(self, handle, handle_device):
    if not isinstance(handle, ops.Tensor):
      raise ValueError(
          (f"Passed handle={handle} to EagerResourceDeleter. Was expecting "
           f"the handle to be a `tf.Tensor`."))
    self._handle = handle
    self._handle_device = handle_device
    # This is held since the __del__ function runs an op, and if the context()
    # is collected before this object, there will be a segfault when running
    # the op.
    self._context = context.context()

  def __del__(self):
    # Resources follow object-identity when executing eagerly, so it is safe to
    # delete the resource we have a handle to.
    try:
      # A packed EagerTensor doesn't own any resource.
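      # The per-device component tensors it packs manage their own resources,
      # so there is nothing for this deleter to destroy.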
      if isinstance(self._handle, ops.EagerTensor) and self._handle.is_packed:
        return
      # This resource was created in eager mode. However, this destructor may
      # be running in graph mode (especially during unit tests). To clean up
      # successfully, we switch back into eager mode temporarily.
      with context.eager_mode():
        with ops.device(self._handle_device):
          gen_resource_variable_ops.destroy_resource_op(
              self._handle, ignore_lookup_error=True)
    except TypeError:
      # Suppress some exceptions, mainly for the case when we're running on
      # module deletion. Things that can go wrong include the context module
      # already being unloaded, self._handle._handle_data no longer being
      # valid, and so on. Printing warnings in these cases is silly
      # (exceptions raised from __del__ are printed as warnings to stderr).
      pass  # 'NoneType' object is not callable when the handle has been
      # partially unloaded.
    except AttributeError:
      pass  # 'NoneType' object has no attribute 'eager_mode' when context has
      # been unloaded. Will catch other module unloads as well.


def shape_safe_assign_variable_handle(handle, shape, value, name=None):
  """Helper that checks shape compatibility and assigns variable."""
  with _handle_graph(handle):
    value_tensor = ops.convert_to_tensor(value)
    shape.assert_is_compatible_with(value_tensor.shape)
    return gen_resource_variable_ops.assign_variable_op(
        handle, value_tensor, name=name)


def _maybe_set_handle_data(dtype, handle, tensor):
  if dtype == dtypes.variant:
    # For DT_VARIANT types, the handle's shape_and_type[1:] stores the
    # variant's handle data. Extract it.
    handle_data = get_eager_safe_handle_data(handle)
    if handle_data.is_set and len(handle_data.shape_and_type) > 1:
      tensor._handle_data = (  # pylint: disable=protected-access
          cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
              is_set=True, shape_and_type=handle_data.shape_and_type[1:]))


def variable_accessed(variable):
  """Records that `variable` was accessed for the tape and FuncGraph."""
  if hasattr(ops.get_default_graph(), "watch_variable"):
    ops.get_default_graph().watch_variable(variable)
  if variable.trainable:
    tape.variable_accessed(variable)


class BaseResourceVariable(variables.VariableV1, core.Tensor):
  """A python variable from an existing handle."""

  # TODO(wangpeng): Deprecate `constraint` when callers no longer pass it in.
  def __init__(  # pylint: disable=super-init-not-called
      self,
      trainable=None,
      shape=None,
      dtype=None,
      handle=None,
      constraint=None,
      synchronization=None,
      aggregation=None,
      distribute_strategy=None,
      name=None,
      unique_id=None,
      handle_name=None,
      graph_element=None,
      initial_value=None,
      initializer_op=None,
      is_initialized_op=None,
      cached_value=None,
      save_slice_info=None,
      caching_device=None,
      in_graph_mode=None,
      validate_shape=True,
      **unused_kwargs):
    """Creates a variable from a handle.

    Args:
      trainable: If `True`, GradientTapes automatically watch uses of this
        Variable.
      shape: The variable's shape. This shape can be set to
        tf.TensorShape(None) in order to assign values of different shapes to
        this variable. Otherwise (i.e. if the shape is fully determined), it
        will trigger run time checks to ensure that each assignment is of the
        same shape.
      dtype: The variable's dtype.
      handle: The variable's handle.
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must
        have the same shape). Constraints are not safe to use when doing
        asynchronous distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      distribute_strategy: The distribution strategy this variable was created
        under.
      name: The name for this variable.
      unique_id: Internal. Unique ID for this variable's handle.
      handle_name: The name for the variable's handle.
      graph_element: Optional, required only in session.run-mode. Pre-created
        tensor which reads this variable's value.
      initial_value: Optional. Variable's initial value.
      initializer_op: Operation which assigns the variable's initial value.
      is_initialized_op: Pre-created operation to check whether this variable
        is initialized.
      cached_value: Pre-created operation to read this variable in a specific
        device.
      save_slice_info: Metadata for variable partitioning.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      in_graph_mode: Whether we are executing in TF1 graph mode. If None, will
        detect within the function. This is to avoid repeated init_scope()
        context entries, which can add up.
      validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
    """
    if in_graph_mode is None:
      with ops.init_scope():
        self._in_graph_mode = not context.executing_eagerly()
    else:
      self._in_graph_mode = in_graph_mode
    synchronization, aggregation, trainable = (
        variables.validate_synchronization_aggregation_trainable(
            synchronization, aggregation, trainable, name))
    self._trainable = trainable
    self._synchronization = synchronization
    self._aggregation = aggregation
    self._save_slice_info = save_slice_info
    self._initial_value = initial_value
    self._initializer_op = initializer_op
    self._is_initialized_op = is_initialized_op
    self._graph_element = graph_element
    self._caching_device = caching_device
    self._cached_value = cached_value
    self._distribute_strategy = distribute_strategy
    # Store the graph key so optimizers know how to only retrieve variables
    # from this graph. Guaranteed to be the same as the eager graph_key.
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    self._shape = tensor_shape.as_shape(shape)
    self._dtype = dtypes.as_dtype(dtype)
    self._handle = handle
    self._unique_id = unique_id
    if handle_name is None:
      self._handle_name = "Variable:0"
    else:
      self._handle_name = handle_name + ":0"
    self._constraint = constraint
    self._cached_shape_as_list = None
    self._validate_shape = validate_shape

  def __repr__(self):
    if context.executing_eagerly() and not self._in_graph_mode:
      # If we cannot read the value for any reason (e.g. variable uninitialized
      # during tf.function tracing), still produce a __repr__. Note that for
      # async eager, errors due to uninitialized variables will raise in
      # ops.value_text when the handle is resolved, so we need to keep that
      # under the try...except if we want to suppress them.
      try:
        with ops.device(self.device):
          value_text = ops.value_text(self.read_value(), is_repr=True)
      except:  # pylint: disable=bare-except
        value_text = "numpy=<unavailable>"

      return "<tf.Variable '%s' shape=%s dtype=%s, %s>" % (
          self.name, self.get_shape(), self.dtype.name, value_text)
    else:
      return "<tf.Variable '%s' shape=%s dtype=%s>" % (
          self.name, self.get_shape(), self.dtype.name)

  def __tf_tracing_type__(self, signature_context):
    return signature_context.make_reference_type(
        VariableSpec(self.shape, self.dtype), self._handle._id)  # pylint:disable=protected-access

  @contextlib.contextmanager
  def _assign_dependencies(self):
    """Makes assignments depend on the cached value, if any.

    This prevents undefined behavior with reads not ordered wrt writes.

    Yields:
      None.
    """
    if self._cached_value is not None:
      with ops.control_dependencies([self._cached_value]):
        yield
    else:
      yield

  def __array__(self, dtype=None):
    """Allows direct conversion to a numpy array.

    >>> np.array(tf.Variable([1.0]))
    array([1.], dtype=float32)

    Returns:
      The variable value as a numpy array.
    """
    # You can't return `self.numpy()` here because for scalars
    # that raises:
    #   ValueError: object __array__ method not producing an array
    # Even `self.read_value().__array__()` and `self.read_value()._numpy()`
    # give the same error. The `EagerTensor` class must be doing something
    # behind the scenes to make `np.array(tf.constant(1))` work.
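    # Going through np.asarray (instead of returning numpy() directly) keeps
    # the scalar case working and honors an explicitly requested dtype.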
    return np.asarray(self.numpy(), dtype=dtype)

  def __nonzero__(self):
    return self.__bool__()

  def __bool__(self):
    return bool(self.read_value())

  def __copy__(self):
    return self

  def __deepcopy__(self, memo):
    if not context.executing_eagerly():
      raise NotImplementedError(
          "__deepcopy__() is only available when eager execution is enabled.")
    copied_variable = ResourceVariable(
        initial_value=self.read_value(),
        trainable=self._trainable,
        constraint=self._constraint,
        dtype=self._dtype,
        name=self._shared_name,
        distribute_strategy=self._distribute_strategy,
        synchronization=self.synchronization,
        aggregation=self.aggregation)
    memo[self._unique_id] = copied_variable
    return copied_variable

  @property
  def dtype(self):
    """The dtype of this variable."""
    return self._dtype

  @property
  def device(self):
    """The device this variable is on."""
    return self.handle.device

  @property
  def graph(self):
    """The `Graph` of this variable."""
    return self.handle.graph

  @property
  def name(self):
    """The name of the handle for this variable."""
    return self._handle_name

  @property
  def shape(self):
    """The shape of this variable."""
    return self._shape

  def set_shape(self, shape):
    self._shape = self._shape.merge_with(shape)

  def _shape_as_list(self):
    if self.shape.ndims is None:
      return None
    return [dim.value for dim in self.shape.dims]

  def _shape_tuple(self):
    shape = self._shape_as_list()
    if shape is None:
      return None
    return tuple(shape)

  @property
  def create(self):
    """The op responsible for initializing this variable."""
    if not self._in_graph_mode:
      raise RuntimeError("This operation is not supported "
                         "when eager execution is enabled.")
    return self._initializer_op

  @property
  def handle(self):
    """The handle by which this variable can be accessed."""
    return self._handle

  def value(self):
    """A cached operation which reads the value of this variable."""
    if self._cached_value is not None:
      return self._cached_value
    with ops.colocate_with(None, ignore_existing=True):
      return self._read_variable_op()

  def _as_graph_element(self):
    """Conversion function for Graph.as_graph_element()."""
    return self._graph_element

  @property
  def initializer(self):
    """The op responsible for initializing this variable."""
    return self._initializer_op

  @property
  def initial_value(self):
    """Returns the Tensor used as the initial value for the variable."""
    if context.executing_eagerly():
      raise RuntimeError("This property is not supported "
                         "when eager execution is enabled.")
    return self._initial_value

  @property
  def constraint(self):
    """Returns the constraint function associated with this variable.

    Returns:
      The constraint function that was passed to the variable constructor.
      Can be `None` if no constraint was passed.
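
    For example, a simple clipping constraint can be attached at construction
    time: `tf.Variable(1.0, constraint=lambda x: tf.clip_by_value(x, 0., 1.))`.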
610 """ 611 return self._constraint 612 613 @property 614 def op(self): 615 """The op for this variable.""" 616 return self.handle.op 617 618 @property 619 def trainable(self): 620 return self._trainable 621 622 @property 623 def synchronization(self): 624 return self._synchronization 625 626 @property 627 def aggregation(self): 628 return self._aggregation 629 630 def eval(self, session=None): 631 """Evaluates and returns the value of this variable.""" 632 if context.executing_eagerly(): 633 raise RuntimeError("This operation is not supported " 634 "when eager execution is enabled.") 635 return self._graph_element.eval(session=session) 636 637 def numpy(self): 638 if context.executing_eagerly(): 639 return self.read_value().numpy() 640 raise NotImplementedError( 641 "numpy() is only available when eager execution is enabled.") 642 643 @deprecated(None, "Prefer Dataset.range instead.") 644 def count_up_to(self, limit): 645 """Increments this variable until it reaches `limit`. 646 647 When that Op is run it tries to increment the variable by `1`. If 648 incrementing the variable would bring it above `limit` then the Op raises 649 the exception `OutOfRangeError`. 650 651 If no error is raised, the Op outputs the value of the variable before 652 the increment. 653 654 This is essentially a shortcut for `count_up_to(self, limit)`. 655 656 Args: 657 limit: value at which incrementing the variable raises an error. 658 659 Returns: 660 A `Tensor` that will hold the variable value before the increment. If no 661 other Op modifies this variable, the values produced will all be 662 distinct. 663 """ 664 return gen_state_ops.resource_count_up_to( 665 self.handle, limit=limit, T=self.dtype) 666 667 def _map_resources(self, save_options): 668 """For implementing `Trackable`.""" 669 new_variable = None 670 if save_options.experimental_variable_policy._save_variable_devices(): # pylint:disable=protected-access 671 with ops.device(self.device): 672 new_variable = copy_to_graph_uninitialized(self) 673 else: 674 new_variable = copy_to_graph_uninitialized(self) 675 obj_map = {self: new_variable} 676 resource_map = {self.handle: new_variable.handle} 677 return obj_map, resource_map 678 679 def _read_variable_op(self, no_copy=False): 680 """Reads the value of the variable. 681 682 If the variable is in copy-on-read mode and `no_copy` is True, the variable 683 is converted to copy-on-write mode before it is read. 684 685 Args: 686 no_copy: Whether to prevent a copy of the variable. 687 688 Returns: 689 The value of the variable. 690 """ 691 variable_accessed(self) 692 693 def read_and_set_handle(no_copy): 694 if no_copy and forward_compat.forward_compatible(2022, 5, 3): 695 gen_resource_variable_ops.disable_copy_on_read(self.handle) 696 result = gen_resource_variable_ops.read_variable_op( 697 self.handle, self._dtype) 698 _maybe_set_handle_data(self._dtype, self.handle, result) 699 return result 700 701 if getattr(self, "_caching_device", None) is not None: 702 with ops.colocate_with(None, ignore_existing=True): 703 with ops.device(self._caching_device): 704 result = read_and_set_handle(no_copy) 705 else: 706 result = read_and_set_handle(no_copy) 707 708 if not context.executing_eagerly(): 709 # Note that if a control flow context is active the input of the read op 710 # might not actually be the handle. This line bypasses it. 
      tape.record_operation(
          "ReadVariableOp", [result], [self.handle],
          backward_function=lambda x: [x],
          forward_function=lambda x: [x])
    return result

  def read_value(self):
    """Constructs an op which reads the value of this variable.

    Should be used when there are multiple reads, or when it is desirable to
    read the value only after some condition is true.

    Returns:
      The value of the variable.
    """
    with ops.name_scope("Read"):
      value = self._read_variable_op()
    # Return an identity so it can get placed on whatever device the context
    # specifies instead of the device where the variable is.
    return array_ops.identity(value)

  def read_value_no_copy(self):
    """Constructs an op which reads the value of this variable without copy.

    The variable is read without making a copy even when it has been sparsely
    accessed. Variables in copy-on-read mode will be converted to copy-on-write
    mode.

    Returns:
      The value of the variable.
    """
    with ops.name_scope("Read"):
      value = self._read_variable_op(no_copy=True)
    # Return an identity so it can get placed on whatever device the context
    # specifies instead of the device where the variable is.
    return array_ops.identity(value)

  def sparse_read(self, indices, name=None):
    """Reads the value of this variable sparsely, using `gather`."""
    with ops.name_scope("Gather" if name is None else name) as name:
      variable_accessed(self)
      value = gen_resource_variable_ops.resource_gather(
          self.handle, indices, dtype=self._dtype, name=name)

      if self._dtype == dtypes.variant:
        # For DT_VARIANT types, the handle's shape_and_type[1:] stores the
        # variant's handle data. Extract it.
        handle_data = get_eager_safe_handle_data(self.handle)
        if handle_data.is_set and len(handle_data.shape_and_type) > 1:
          value._handle_data = (  # pylint: disable=protected-access
              cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
                  is_set=True, shape_and_type=handle_data.shape_and_type[1:]))

    return array_ops.identity(value)

  def gather_nd(self, indices, name=None):
    """Reads the value of this variable sparsely, using `gather_nd`."""
    with ops.name_scope("GatherNd" if name is None else name) as name:
      if self.trainable:
        variable_accessed(self)
      value = gen_resource_variable_ops.resource_gather_nd(
          self.handle, indices, dtype=self._dtype, name=name)

    return array_ops.identity(value)

  def to_proto(self, export_scope=None):
    """Converts a `ResourceVariable` to a `VariableDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Raises:
      RuntimeError: If run in EAGER mode.

    Returns:
      A `VariableDef` protocol buffer, or `None` if the `Variable` is not
      in the specified name scope.
    """
    if context.executing_eagerly():
      raise RuntimeError("This operation is not supported "
                         "when eager execution is enabled.")
    if export_scope is None or self.handle.name.startswith(export_scope):
      var_def = variable_pb2.VariableDef()
      var_def.variable_name = ops.strip_name_scope(self.handle.name,
                                                   export_scope)
      if self._initial_value is not None:
        # This is inside an if-statement for backwards compatibility, since
        # self._initial_value might be None for variables constructed from old
        # protos.
        var_def.initial_value_name = ops.strip_name_scope(
            self._initial_value.name, export_scope)
      var_def.initializer_name = ops.strip_name_scope(self.initializer.name,
                                                      export_scope)
      if self._cached_value is not None:
        var_def.snapshot_name = ops.strip_name_scope(self._cached_value.name,
                                                     export_scope)
      else:
        # Store the graph_element here
        var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name,
                                                     export_scope)
      var_def.is_resource = True
      var_def.trainable = self.trainable
      var_def.synchronization = self.synchronization.value
      var_def.aggregation = self.aggregation.value
      if self._save_slice_info:
        var_def.save_slice_info_def.MergeFrom(
            self._save_slice_info.to_proto(export_scope=export_scope))
      return var_def
    else:
      return None

  @staticmethod
  def from_proto(variable_def, import_scope=None):
    if context.executing_eagerly():
      raise RuntimeError("This operation is not supported "
                         "when eager execution is enabled.")
    return ResourceVariable(
        variable_def=variable_def, import_scope=import_scope)

  __array_priority__ = 100

  def is_initialized(self, name=None):
    """Checks whether a resource variable has been initialized.

    Outputs boolean scalar indicating whether the tensor has been initialized.

    Args:
      name: A name for the operation (optional).

    Returns:
      A `Tensor` of type `bool`.
    """
    return gen_resource_variable_ops.var_is_initialized_op(self.handle, name)

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    """Subtracts a value from this variable.

    Args:
      delta: A `Tensor`. The value to subtract from this variable.
      use_locking: If `True`, use locking during the operation.
      name: The name to use for the operation.
      read_value: A `bool`. Whether to read and return the new value of the
        variable or not.

    Returns:
      If `read_value` is `True`, this method will return the new value of the
      variable after the assignment has completed. Otherwise, when in graph
      mode it will return the `Operation` that does the assignment, and when
      in eager mode it will return `None`.
    """
    # TODO(apassos): this here and below is not atomic. Consider making it
    # atomic if there's a way to do so without a performance cost for those who
    # don't need it.
    with _handle_graph(self.handle), self._assign_dependencies():
      assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(
          self.handle,
          ops.convert_to_tensor(delta, dtype=self.dtype),
          name=name)
      if read_value:
        return self._lazy_read(assign_sub_op)
    return assign_sub_op

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    """Adds a value to this variable.

    Args:
      delta: A `Tensor`. The value to add to this variable.
      use_locking: If `True`, use locking during the operation.
      name: The name to use for the operation.
      read_value: A `bool`. Whether to read and return the new value of the
        variable or not.

    Returns:
      If `read_value` is `True`, this method will return the new value of the
      variable after the assignment has completed. Otherwise, when in graph
      mode it will return the `Operation` that does the assignment, and when
      in eager mode it will return `None`.
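
    For example: `v = tf.Variable(1.0); v.assign_add(2.0)` leaves `v` holding
    `3.0`.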
888 """ 889 with _handle_graph(self.handle), self._assign_dependencies(): 890 assign_add_op = gen_resource_variable_ops.assign_add_variable_op( 891 self.handle, 892 ops.convert_to_tensor(delta, dtype=self.dtype), 893 name=name) 894 if read_value: 895 return self._lazy_read(assign_add_op) 896 return assign_add_op 897 898 def _lazy_read(self, op): 899 variable_accessed(self) 900 return _UnreadVariable( 901 handle=self.handle, 902 dtype=self.dtype, 903 shape=self._shape, 904 in_graph_mode=self._in_graph_mode, 905 parent_op=op, 906 unique_id=self._unique_id) 907 908 def assign(self, value, use_locking=None, name=None, read_value=True): 909 """Assigns a new value to this variable. 910 911 Args: 912 value: A `Tensor`. The new value for this variable. 913 use_locking: If `True`, use locking during the assignment. 914 name: The name to use for the assignment. 915 read_value: A `bool`. Whether to read and return the new value of the 916 variable or not. 917 918 Returns: 919 If `read_value` is `True`, this method will return the new value of the 920 variable after the assignment has completed. Otherwise, when in graph mode 921 it will return the `Operation` that does the assignment, and when in eager 922 mode it will return `None`. 923 """ 924 # Note: not depending on the cached value here since this can be used to 925 # initialize the variable. 926 with _handle_graph(self.handle): 927 value_tensor = ops.convert_to_tensor(value, dtype=self.dtype) 928 if not self._shape.is_compatible_with(value_tensor.shape): 929 if self.name is None: 930 tensor_name = "" 931 else: 932 tensor_name = " " + str(self.name) 933 raise ValueError( 934 (f"Cannot assign value to variable '{tensor_name}': Shape mismatch." 935 f"The variable shape {self._shape}, and the " 936 f"assigned value shape {value_tensor.shape} are incompatible.")) 937 kwargs = {} 938 if forward_compat.forward_compatible(2022, 3, 23): 939 # If the shape is fully defined, we do a runtime check with the shape of 940 # value. 941 validate_shape = self._validate_shape and self._shape.is_fully_defined() 942 kwargs["validate_shape"] = validate_shape 943 assign_op = gen_resource_variable_ops.assign_variable_op( 944 self.handle, value_tensor, name=name, **kwargs) 945 if read_value: 946 return self._lazy_read(assign_op) 947 return assign_op 948 949 def __reduce__(self): 950 # The implementation mirrors that of __deepcopy__. 951 return functools.partial( 952 ResourceVariable, 953 initial_value=self.numpy(), 954 trainable=self.trainable, 955 name=self._shared_name, 956 dtype=self.dtype, 957 constraint=self.constraint, 958 distribute_strategy=self._distribute_strategy), () 959 960 def scatter_sub(self, sparse_delta, use_locking=False, name=None): 961 """Subtracts `tf.IndexedSlices` from this variable. 962 963 Args: 964 sparse_delta: `tf.IndexedSlices` to be subtracted from this variable. 965 use_locking: If `True`, use locking during the operation. 966 name: the name of the operation. 967 968 Returns: 969 The updated variable. 970 971 Raises: 972 TypeError: if `sparse_delta` is not an `IndexedSlices`. 973 """ 974 if not isinstance(sparse_delta, indexed_slices.IndexedSlices): 975 raise TypeError(f"Argument `sparse_delta` must be a " 976 f"`tf.IndexedSlices`. 
    return self._lazy_read(
        gen_resource_variable_ops.resource_scatter_sub(
            self.handle,
            sparse_delta.indices,
            ops.convert_to_tensor(sparse_delta.values, self.dtype),
            name=name))

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    """Adds `tf.IndexedSlices` to this variable.

    Args:
      sparse_delta: `tf.IndexedSlices` to be added to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
      raise TypeError(f"Argument `sparse_delta` must be a "
                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
    return self._lazy_read(
        gen_resource_variable_ops.resource_scatter_add(
            self.handle,
            sparse_delta.indices,
            ops.convert_to_tensor(sparse_delta.values, self.dtype),
            name=name))

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    """Updates this variable with the max of `tf.IndexedSlices` and itself.

    Args:
      sparse_delta: `tf.IndexedSlices` to use as an argument of max with this
        variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
      raise TypeError(f"Argument `sparse_delta` must be a "
                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
    return self._lazy_read(
        gen_resource_variable_ops.resource_scatter_max(
            self.handle,
            sparse_delta.indices,
            ops.convert_to_tensor(sparse_delta.values, self.dtype),
            name=name))

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    """Updates this variable with the min of `tf.IndexedSlices` and itself.

    Args:
      sparse_delta: `tf.IndexedSlices` to use as an argument of min with this
        variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
      raise TypeError(f"Argument `sparse_delta` must be a "
                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
    return self._lazy_read(
        gen_resource_variable_ops.resource_scatter_min(
            self.handle,
            sparse_delta.indices,
            ops.convert_to_tensor(sparse_delta.values, self.dtype),
            name=name))

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    """Multiply this variable by `tf.IndexedSlices`.

    Args:
      sparse_delta: `tf.IndexedSlices` to multiply this variable by.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
      raise TypeError(f"Argument `sparse_delta` must be a "
                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
    return self._lazy_read(
        gen_resource_variable_ops.resource_scatter_mul(
            self.handle,
            sparse_delta.indices,
            ops.convert_to_tensor(sparse_delta.values, self.dtype),
            name=name))

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    """Divide this variable by `tf.IndexedSlices`.

    Args:
      sparse_delta: `tf.IndexedSlices` to divide this variable by.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
      raise TypeError(f"Argument `sparse_delta` must be a "
                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
    return self._lazy_read(
        gen_resource_variable_ops.resource_scatter_div(
            self.handle,
            sparse_delta.indices,
            ops.convert_to_tensor(sparse_delta.values, self.dtype),
            name=name))

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Assigns `tf.IndexedSlices` to this variable.

    Args:
      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
      raise TypeError(f"Argument `sparse_delta` must be a "
                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
    return self._lazy_read(
        gen_resource_variable_ops.resource_scatter_update(
            self.handle,
            sparse_delta.indices,
            ops.convert_to_tensor(sparse_delta.values, self.dtype),
            name=name))

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Assigns `tf.IndexedSlices` to this variable batch-wise.

    Analogous to `batch_gather`. This assumes that this variable and the
    sparse_delta IndexedSlices have a series of leading dimensions that are the
    same for all of them, and the updates are performed on the last dimension
    of indices. In other words, the dimensions should be the following:

    `num_prefix_dims = sparse_delta.indices.ndims - 1`
    `batch_dim = num_prefix_dims + 1`
    `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[
         batch_dim:]`

    where

    `sparse_delta.updates.shape[:num_prefix_dims]`
    `== sparse_delta.indices.shape[:num_prefix_dims]`
    `== var.shape[:num_prefix_dims]`

    And the operation performed can be expressed as:

    `var[i_1, ..., i_n,
         sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[
            i_1, ..., i_n, j]`

    When sparse_delta.indices is a 1D tensor, this operation is equivalent to
    `scatter_update`.

    To avoid this operation one can loop over the first `ndims` of the
    variable and use `scatter_update` on the subtensors that result from
    slicing the first dimension. This is a valid option for `ndims = 1`, but
    less efficient than this implementation.

    Args:
      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
      raise TypeError(f"Argument `sparse_delta` must be a "
                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
    return self._lazy_read(
        state_ops.batch_scatter_update(
            self,
            sparse_delta.indices,
            sparse_delta.values,
            use_locking=use_locking,
            name=name))

  def scatter_nd_sub(self, indices, updates, name=None):
    """Applies sparse subtraction to individual values or slices in a Variable.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be an integer tensor, containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    For example, say we want to subtract 4 scattered elements from a rank-1
    tensor with 8 elements. In Python, that update would look like this:

    ```python
    ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1], [7]])
    updates = tf.constant([9, 10, 11, 12])
    op = ref.scatter_nd_sub(indices, updates)
    with tf.compat.v1.Session() as sess:
      print(sess.run(op))
    ```

    The resulting update to ref would look like this:

        [1, -9, 3, -6, -6, 6, 7, -4]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      The updated variable.
    """
    return self._lazy_read(
        gen_state_ops.resource_scatter_nd_sub(
            self.handle,
            indices,
            ops.convert_to_tensor(updates, self.dtype),
            name=name))

  def scatter_nd_add(self, indices, updates, name=None):
    """Applies sparse addition to individual values or slices in a Variable.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be an integer tensor, containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    For example, say we want to add 4 scattered elements to a rank-1 tensor
    with 8 elements. In Python, that update would look like this:

    ```python
    ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1], [7]])
    updates = tf.constant([9, 10, 11, 12])
    add = ref.scatter_nd_add(indices, updates)
    with tf.compat.v1.Session() as sess:
      print(sess.run(add))
    ```

    The resulting update to ref would look like this:

        [1, 13, 3, 14, 14, 6, 7, 20]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      The updated variable.
    """
    return self._lazy_read(
        gen_state_ops.resource_scatter_nd_add(
            self.handle,
            indices,
            ops.convert_to_tensor(updates, self.dtype),
            name=name))

  def scatter_nd_update(self, indices, updates, name=None):
    """Applies sparse assignment to individual values or slices in a Variable.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be an integer tensor, containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    For example, say we want to assign 4 scattered elements in a rank-1 tensor
    with 8 elements. In Python, that update would look like this:

    ```python
    ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1], [7]])
    updates = tf.constant([9, 10, 11, 12])
    op = ref.scatter_nd_update(indices, updates)
    with tf.compat.v1.Session() as sess:
      print(sess.run(op))
    ```

    The resulting update to ref would look like this:

        [1, 11, 3, 10, 9, 6, 7, 12]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      The updated variable.
    """
    return self._lazy_read(
        gen_state_ops.resource_scatter_nd_update(
            self.handle,
            indices,
            ops.convert_to_tensor(updates, self.dtype),
            name=name))

  def scatter_nd_max(self, indices, updates, name=None):
    """Updates this variable with the max of `tf.IndexedSlices` and itself.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be an integer tensor, containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      The updated variable.
    """
    return self._lazy_read(
        gen_state_ops.resource_scatter_nd_max(
            self.handle,
            indices,
            ops.convert_to_tensor(updates, self.dtype),
            name=name))

  def scatter_nd_min(self, indices, updates, name=None):
    """Updates this variable with the min of `tf.IndexedSlices` and itself.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be an integer tensor, containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      The updated variable.
    """
    return self._lazy_read(
        gen_state_ops.resource_scatter_nd_min(
            self.handle,
            indices,
            ops.convert_to_tensor(updates, self.dtype),
            name=name))

  def _write_object_proto(self, proto, options):
    """Writes additional information of the variable into the SavedObject proto.

    Subclasses of ResourceVariables could choose to override this method to
    customize extra information to provide when saving a SavedModel.

    Ideally, this should contain the logic in
    write_object_proto_for_resource_variable but `DistributedValue` is an
    outlier at the moment. Once `DistributedValue` becomes a proper
    ResourceVariable, we should remove the helper method below.

    Args:
      proto: `SavedObject` proto to update.
      options: A `SaveOption` instance that configures save behavior.
    """
    write_object_proto_for_resource_variable(self, proto, options)

  def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
                            end_mask, ellipsis_mask, new_axis_mask,
                            shrink_axis_mask):
    with _handle_graph(self.handle), self._assign_dependencies():
      return self._lazy_read(
          gen_array_ops.resource_strided_slice_assign(
              ref=self.handle,
              begin=begin,
              end=end,
              strides=strides,
              value=ops.convert_to_tensor(value, dtype=self.dtype),
              name=name,
              begin_mask=begin_mask,
              end_mask=end_mask,
              ellipsis_mask=ellipsis_mask,
              new_axis_mask=new_axis_mask,
              shrink_axis_mask=shrink_axis_mask))

  def __complex__(self):
    return complex(self.value().numpy())

  def __int__(self):
    return int(self.value().numpy())

  def __long__(self):
    return long(self.value().numpy())

  def __float__(self):
    return float(self.value().numpy())

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    del name
    if dtype is not None and not dtype.is_compatible_with(self.dtype):
      raise ValueError(
          f"Incompatible type conversion requested to type {dtype.name} for "
          f"`tf.Variable` of type {self.dtype.name}. (Variable: {self})")
    if as_ref:
      return self.read_value().op.inputs[0]
    else:
      return self.value()

  def __iadd__(self, unused_other):
    raise RuntimeError("`variable += value` with `tf.Variable`s is not "
                       "supported. Use `variable.assign_add(value)` to modify "
                       "the variable, or `out = variable + value` if you "
                       "need to get a new output Tensor.")

  def __isub__(self, unused_other):
    raise RuntimeError("`variable -= value` with `tf.Variable`s is not "
                       "supported. Use `variable.assign_sub(value)` to modify "
                       "the variable, or `out = variable - value` if you "
                       "need to get a new output Tensor.")

  def __imul__(self, unused_other):
    raise RuntimeError("`var *= value` with `tf.Variable`s is not "
                       "supported. Use `var.assign(var * value)` to modify "
                       "the variable, or `out = var * value` if you "
                       "need to get a new output Tensor.")

  def __idiv__(self, unused_other):
    raise RuntimeError("`var /= value` with `tf.Variable`s is not "
                       "supported. Use `var.assign(var / value)` to modify "
                       "the variable, or `out = var / value` if you "
                       "need to get a new output Tensor.")

  def __itruediv__(self, unused_other):
    raise RuntimeError("`var /= value` with `tf.Variable`s is not "
                       "supported. Use `var.assign(var / value)` to modify "
                       "the variable, or `out = var / value` if you "
                       "need to get a new output Tensor.")

  def __irealdiv__(self, unused_other):
    raise RuntimeError("`var /= value` with `tf.Variable`s is not "
                       "supported. Use `var.assign(var / value)` to modify "
                       "the variable, or `out = var / value` if you "
                       "need to get a new output Tensor.")

  def __ipow__(self, unused_other):
    raise RuntimeError("`var **= value` with `tf.Variable`s is not "
                       "supported. Use `var.assign(var ** value)` to modify "
                       "the variable, or `out = var ** value` if you "
                       "need to get a new output Tensor.")


class ResourceVariable(BaseResourceVariable):
  """Variable based on resource handles.

  See the [Variables How To](https://tensorflow.org/guide/variables)
  for a high level overview.

  A `ResourceVariable` allows you to maintain state across subsequent calls to
  session.run.

  The `ResourceVariable` constructor requires an initial value for the
  variable, which can be a `Tensor` of any type and shape. The initial value
  defines the type and shape of the variable. After construction, the type and
  shape of the variable are fixed. The value can be changed using one of the
  assign methods.

  Just like any `Tensor`, variables created with
  `tf.Variable(use_resource=True)` can be used as inputs for other Ops in the
  graph. Additionally, all the operators overloaded for the `Tensor` class are
  carried over to variables, so you can also add nodes to the graph by just
  doing arithmetic on variables.

  Unlike ref-based variables, a ResourceVariable has well-defined semantics.
  Each usage of a ResourceVariable in a TensorFlow graph adds a read_value
  operation to the graph. The Tensors returned by a read_value operation are
  guaranteed to see all modifications to the value of the variable which
  happen in any operation on which the read_value depends (either directly,
  indirectly, or via a control dependency) and guaranteed to not see any
  modification to the value of the variable from operations that depend on the
  read_value operation. Updates from operations that have no dependency
  relationship to the read_value operation might or might not be visible to
  read_value.

  For example, if there is more than one assignment to a ResourceVariable in
  a single session.run call there is a well-defined value for each operation
  which uses the variable's value if the assignments and the read are connected
  by edges in the graph.
Consider the following example, in which two writes 1549 can cause tf.Variable and tf.ResourceVariable to behave differently: 1550 1551 ```python 1552 a = tf.Variable(1.0, use_resource=True) 1553 a.initializer.run() 1554 1555 assign = a.assign(2.0) 1556 with tf.control_dependencies([assign]): 1557 b = a.read_value() 1558 with tf.control_dependencies([b]): 1559 other_assign = a.assign(3.0) 1560 with tf.control_dependencies([other_assign]): 1561 # Will print 2.0 because the value was read before other_assign ran. If 1562 # `a` was a tf.Variable instead, 2.0 or 3.0 could be printed. 1563 tf.compat.v1.Print(b, [b]).eval() 1564 ``` 1565 """ 1566 1567 def __init__( 1568 self, # pylint: disable=super-init-not-called 1569 initial_value=None, 1570 trainable=None, 1571 collections=None, 1572 validate_shape=True, # pylint: disable=unused-argument 1573 caching_device=None, 1574 name=None, 1575 dtype=None, 1576 variable_def=None, 1577 import_scope=None, 1578 constraint=None, 1579 distribute_strategy=None, 1580 synchronization=None, 1581 aggregation=None, 1582 shape=None): 1583 """Creates a variable. 1584 1585 Args: 1586 initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 1587 which is the initial value for the Variable. Can also be a callable with 1588 no argument that returns the initial value when called. (Note that 1589 initializer functions from init_ops.py must first be bound to a shape 1590 before being used here.) 1591 trainable: If `True`, the default, also adds the variable to the graph 1592 collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as 1593 the default list of variables to use by the `Optimizer` classes. 1594 Defaults to `True`, unless `synchronization` is set to `ON_READ`, in 1595 which case it defaults to `False`. 1596 collections: List of graph collections keys. The new variable is added to 1597 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 1598 validate_shape: If `False`, allows the variable to be initialized with a 1599 value of unknown shape. If `True`, the default, the shape of 1600 `initial_value` must be known. 1601 caching_device: Optional device string or function describing where the 1602 Variable should be cached for reading. Defaults to the Variable's 1603 device. If not `None`, caches on another device. Typical use is to 1604 cache on the device where the Ops using the Variable reside, to 1605 deduplicate copying through `Switch` and other conditional statements. 1606 name: Optional name for the variable. Defaults to `'Variable'` and gets 1607 uniquified automatically. 1608 dtype: If set, initial_value will be converted to the given type. If None, 1609 either the datatype will be kept (if initial_value is a Tensor) or 1610 float32 will be used (if it is a Python object convertible to a Tensor). 1611 variable_def: `VariableDef` protocol buffer. If not None, recreates the 1612 `ResourceVariable` object with its contents. `variable_def` and other 1613 arguments (except for import_scope) are mutually exclusive. 1614 import_scope: Optional `string`. Name scope to add to the 1615 ResourceVariable. Only used when `variable_def` is provided. 1616 constraint: An optional projection function to be applied to the variable 1617 after being updated by an `Optimizer` (e.g. used to implement norm 1618 constraints or value constraints for layer weights). 
The function must 1619 take as input the unprojected Tensor representing the value of the 1620 variable and return the Tensor for the projected value (which must have 1621 the same shape). Constraints are not safe to use when doing asynchronous 1622 distributed training. 1623 distribute_strategy: The tf.distribute.Strategy this variable is being 1624 created inside of. 1625 synchronization: Indicates when a distributed a variable will be 1626 aggregated. Accepted values are constants defined in the class 1627 `tf.VariableSynchronization`. By default the synchronization is set to 1628 `AUTO` and the current `DistributionStrategy` chooses when to 1629 synchronize. 1630 aggregation: Indicates how a distributed variable will be aggregated. 1631 Accepted values are constants defined in the class 1632 `tf.VariableAggregation`. 1633 shape: (optional) The shape of this variable. If None, the shape of 1634 `initial_value` will be used. When setting this argument to 1635 `tf.TensorShape(None)` (representing an unspecified shape), the variable 1636 can be assigned with values of different shapes. 1637 1638 Raises: 1639 ValueError: If the initial value is not specified, or does not have a 1640 shape and `validate_shape` is `True`. 1641 1642 @compatibility(eager) 1643 When Eager Execution is enabled, the default for the `collections` argument 1644 is `None`, which signifies that this `Variable` will not be added to any 1645 collections. 1646 @end_compatibility 1647 """ 1648 if variable_def: 1649 if initial_value is not None: 1650 raise ValueError(f"The variable_def and initial_value args to " 1651 f"`tf.Variable` are mutually exclusive, but got both: " 1652 f"variable_def={variable_def},\n" 1653 f"initial_value={initial_value}") 1654 if context.executing_eagerly(): 1655 raise ValueError(f"Creating a `tf.Variable` with a `variable_def` arg " 1656 f"is not supported when eager execution is enabled. " 1657 f"Got: variable_def={variable_def}") 1658 self._init_from_proto( 1659 variable_def, 1660 import_scope=import_scope, 1661 validate_shape=validate_shape) 1662 else: 1663 self._init_from_args( 1664 initial_value=initial_value, 1665 trainable=trainable, 1666 collections=collections, 1667 caching_device=caching_device, 1668 name=name, 1669 dtype=dtype, 1670 constraint=constraint, 1671 synchronization=synchronization, 1672 aggregation=aggregation, 1673 shape=shape, 1674 distribute_strategy=distribute_strategy, 1675 validate_shape=validate_shape, 1676 ) 1677 1678 def _init_from_args( 1679 self, 1680 initial_value=None, 1681 trainable=None, 1682 collections=None, 1683 caching_device=None, 1684 name=None, 1685 dtype=None, 1686 constraint=None, 1687 synchronization=None, 1688 aggregation=None, 1689 distribute_strategy=None, 1690 shape=None, 1691 validate_shape=True, 1692 ): 1693 """Creates a variable. 1694 1695 Args: 1696 initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 1697 which is the initial value for the Variable. The initial value must have 1698 a shape specified unless `validate_shape` is set to False. Can also be a 1699 callable with no argument that returns the initial value when called. 1700 (Note that initializer functions from init_ops.py must first be bound to 1701 a shape before being used here.) 1702 trainable: If `True`, the default, also adds the variable to the graph 1703 collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as 1704 the default list of variables to use by the `Optimizer` classes. 
1705 Defaults to `True`, unless `synchronization` is set to `ON_READ`, in 1706 which case it defaults to `False`. 1707 collections: List of graph collections keys. The new variable is added to 1708 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 1709 caching_device: Optional device string or function describing where the 1710 Variable should be cached for reading. Defaults to the Variable's 1711 device. If not `None`, caches on another device. Typical use is to 1712 cache on the device where the Ops using the Variable reside, to 1713 deduplicate copying through `Switch` and other conditional statements. 1714 name: Optional name for the variable. Defaults to `'Variable'` and gets 1715 uniquified automatically. 1716 dtype: If set, initial_value will be converted to the given type. If None, 1717 either the datatype will be kept (if initial_value is a Tensor) or 1718 float32 will be used (if it is a Python object convertible to a Tensor). 1719 constraint: An optional projection function to be applied to the variable 1720 after being updated by an `Optimizer` (e.g. used to implement norm 1721 constraints or value constraints for layer weights). The function must 1722 take as input the unprojected Tensor representing the value of the 1723 variable and return the Tensor for the projected value (which must have 1724 the same shape). Constraints are not safe to use when doing asynchronous 1725 distributed training. 1726 synchronization: Indicates when a distributed variable will be 1727 aggregated. Accepted values are constants defined in the class 1728 `tf.VariableSynchronization`. By default the synchronization is set to 1729 `AUTO` and the current `DistributionStrategy` chooses when to 1730 synchronize. 1731 aggregation: Indicates how a distributed variable will be aggregated. 1732 Accepted values are constants defined in the class 1733 `tf.VariableAggregation`. 1734 distribute_strategy: DistributionStrategy under which this variable was 1735 created. 1736 shape: (optional) The shape of this variable. If None, the shape of 1737 `initial_value` will be used. When setting this argument to 1738 `tf.TensorShape(None)` (representing an unspecified shape), the variable 1739 can be assigned with values of different shapes. 1740 validate_shape: If `False`, allows the variable to be initialized with a 1741 value of unknown shape. If `True`, the default, the shape of 1742 `initial_value` must be known. 1743 1744 Raises: 1745 ValueError: If the initial value is not specified, or does not have a 1746 shape and `validate_shape` is `True`. 1747 1748 @compatibility(eager) 1749 When Eager Execution is enabled, variables are never added to collections. 1750 They are not implicitly added to the `GLOBAL_VARIABLES` or 1751 `TRAINABLE_VARIABLES` collections, and the `collections` argument is 1752 ignored. 1753 @end_compatibility 1754 """ 1755 synchronization, aggregation, trainable = ( 1756 variables.validate_synchronization_aggregation_trainable( 1757 synchronization, aggregation, trainable, name)) 1758 if initial_value is None: 1759 raise ValueError("The `initial_value` arg to `tf.Variable` must " 1760 "be specified unless you are creating the variable " 1761 "from a `variable_def`. You provided neither.") 1762 init_from_fn = callable(initial_value) 1763 1764 if isinstance(initial_value, ops.Tensor) and hasattr( 1765 initial_value, "graph") and initial_value.graph.building_function: 1766 raise ValueError(f"Argument `initial_value` ({initial_value}) could not " 1767 "be lifted out of a `tf.function`. 
" 1768 f"(Tried to create variable with name='{name}'). " 1769 "To avoid this error, when constructing `tf.Variable`s " 1770 "inside of `tf.function` you can create the " 1771 "`initial_value` tensor in a " 1772 "`tf.init_scope` or pass a callable `initial_value` " 1773 "(e.g., `tf.Variable(lambda : " 1774 "tf.truncated_normal([10, 40]))`). " 1775 "Please file a feature request if this " 1776 "restriction inconveniences you.") 1777 1778 if collections is None: 1779 collections = [ops.GraphKeys.GLOBAL_VARIABLES] 1780 if not isinstance(collections, (list, tuple, set)): 1781 raise ValueError( 1782 f"collections argument to Variable constructor must be a list, " 1783 f"tuple, or set. Got {collections} of type {type(collections)}") 1784 if constraint is not None and not callable(constraint): 1785 raise ValueError(f"Argument `constraint` must be None or a callable. " 1786 f"a callable. Got a {type(constraint)}: {constraint}") 1787 1788 if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: 1789 collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] 1790 with ops.init_scope(): 1791 self._in_graph_mode = not context.executing_eagerly() 1792 with ops.name_scope( 1793 name, 1794 "Variable", [] if init_from_fn else [initial_value], 1795 skip_on_eager=False) as name: 1796 # pylint: disable=protected-access 1797 handle_name = ops.name_from_scope_name(name) 1798 if self._in_graph_mode: 1799 shared_name = handle_name 1800 unique_id = shared_name 1801 else: 1802 # When in eager mode use a uid for the shared_name, to prevent 1803 # accidental sharing. 1804 unique_id = "%s_%d" % (handle_name, ops.uid()) 1805 shared_name = None # Never shared 1806 # Use attr_scope and device(None) to simulate the behavior of 1807 # colocate_with when the variable we want to colocate with doesn't 1808 # yet exist. 1809 device_context_manager = ( 1810 ops.device if self._in_graph_mode else ops.NullContextmanager) 1811 attr = attr_value_pb2.AttrValue( 1812 list=attr_value_pb2.AttrValue.ListValue( 1813 s=[compat.as_bytes("loc:@%s" % handle_name)])) 1814 with ops.get_default_graph()._attr_scope({"_class": attr}): 1815 with ops.name_scope("Initializer"), device_context_manager(None): 1816 if init_from_fn: 1817 initial_value = initial_value() 1818 if isinstance(initial_value, trackable.CheckpointInitialValue): 1819 self._maybe_initialize_trackable() 1820 self._update_uid = initial_value.checkpoint_position.restore_uid 1821 initial_value = initial_value.wrapped_value 1822 initial_value = ops.convert_to_tensor( 1823 initial_value, name="initial_value", dtype=dtype) 1824 if shape is not None: 1825 if not initial_value.shape.is_compatible_with(shape): 1826 raise ValueError( 1827 f"In this `tf.Variable` creation, the initial value's shape " 1828 f"({initial_value.shape}) is not compatible with " 1829 f"the explicitly supplied `shape` argument ({shape}).") 1830 else: 1831 shape = initial_value.shape 1832 handle = eager_safe_variable_handle( 1833 initial_value=initial_value, 1834 shape=shape, 1835 shared_name=shared_name, 1836 name=name, 1837 graph_mode=self._in_graph_mode) 1838 handle._parent_trackable = weakref.ref(self) 1839 # pylint: disable=protected-access 1840 if (self._in_graph_mode and initial_value is not None and 1841 initial_value.op._get_control_flow_context() is not None): 1842 raise ValueError( 1843 f"The `initial_value` passed to `tf.Variable` {name} is from " 1844 f"inside a control-flow construct, such as a loop or " 1845 f"conditional. 
When creating a " 1846 f"`tf.Variable` inside a loop or conditional, use a lambda as " 1847 f"the `initial_value`. Got: initial_value=({initial_value})") 1848 # pylint: enable=protected-access 1849 dtype = initial_value.dtype.base_dtype 1850 1851 if self._in_graph_mode: 1852 with ops.name_scope("IsInitialized"): 1853 is_initialized_op = ( 1854 gen_resource_variable_ops.var_is_initialized_op(handle)) 1855 if initial_value is not None: 1856 # pylint: disable=g-backslash-continuation 1857 with ops.name_scope("Assign") as n, \ 1858 ops.colocate_with(None, ignore_existing=True), \ 1859 ops.device(handle.device): 1860 # pylint: disable=protected-access 1861 initializer_op = ( 1862 gen_resource_variable_ops.assign_variable_op( 1863 handle, 1864 variables._try_guard_against_uninitialized_dependencies( 1865 name, initial_value), 1866 name=n)) 1867 # pylint: enable=protected-access 1868 # pylint: enable=g-backslash-continuation 1869 with ops.name_scope("Read"): 1870 # Manually assign reads to the handle's device to avoid log 1871 # messages. 1872 with ops.device(handle.device): 1873 value = gen_resource_variable_ops.read_variable_op(handle, dtype) 1874 _maybe_set_handle_data(dtype, handle, value) 1875 graph_element = value 1876 if caching_device is not None: 1877 # Variables may be created in a tf.device() or ops.colocate_with() 1878 # context. At the same time, users would expect caching device to 1879 # be independent of this context, and/or would not expect the 1880 # current device context to be merged with the caching device 1881 # spec. Therefore we reset the colocation stack before creating 1882 # the cached value. Note that resetting the colocation stack will 1883 # also reset the device stack. 1884 with ops.colocate_with(None, ignore_existing=True): 1885 with ops.device(caching_device): 1886 cached_value = array_ops.identity(value) 1887 else: 1888 cached_value = None 1889 else: 1890 gen_resource_variable_ops.assign_variable_op(handle, initial_value) 1891 is_initialized_op = None 1892 initializer_op = None 1893 graph_element = None 1894 if caching_device: 1895 with ops.device(caching_device): 1896 cached_value = gen_resource_variable_ops.read_variable_op( 1897 handle, dtype) 1898 _maybe_set_handle_data(dtype, handle, cached_value) 1899 else: 1900 cached_value = None 1901 1902 if cached_value is not None: 1903 # Store the variable object so that the original variable can be 1904 # accessed to generate functions that are compatible with SavedModel. 1905 cached_value._cached_variable = weakref.ref(self) # pylint: disable=protected-access 1906 1907 if not context.executing_eagerly(): 1908 # Eager variables are only added to collections if they are part of an 1909 # eager variable store (otherwise in an interactive session they would 1910 # hog memory and cause OOM). This is done in ops/variable_scope.py. 
1911 ops.add_to_collections(collections, self) 1912 elif ops.GraphKeys.GLOBAL_STEP in collections: 1913 ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self) 1914 initial_value = initial_value if self._in_graph_mode else None 1915 super(ResourceVariable, self).__init__( 1916 trainable=trainable, 1917 shape=shape, 1918 dtype=dtype, 1919 handle=handle, 1920 synchronization=synchronization, 1921 constraint=constraint, 1922 aggregation=aggregation, 1923 distribute_strategy=distribute_strategy, 1924 name=name, 1925 unique_id=unique_id, 1926 handle_name=handle_name, 1927 graph_element=graph_element, 1928 initial_value=initial_value, 1929 initializer_op=initializer_op, 1930 is_initialized_op=is_initialized_op, 1931 cached_value=cached_value, 1932 caching_device=caching_device, 1933 validate_shape=validate_shape, 1934 ) 1935 1936 def _init_from_proto(self, 1937 variable_def, 1938 import_scope=None, 1939 validate_shape=True): 1940 """Initializes from `VariableDef` proto.""" 1941 # Note that init_from_proto is currently not supported in Eager mode. 1942 assert not context.executing_eagerly() 1943 self._in_graph_mode = True 1944 assert isinstance(variable_def, variable_pb2.VariableDef) 1945 if not variable_def.is_resource: 1946 raise ValueError(f"The `variable_def` you passed to `tf.Variable` is " 1947 f"not a resource variable. Restoring a TF 1.x " 1948 f"Reference Variable as a TF 2.x ResourceVariable is " 1949 f"unsupported. Got variable_def={variable_def}") 1950 1951 # Create from variable_def. 1952 g = ops.get_default_graph() 1953 self._handle = g.as_graph_element( 1954 ops.prepend_name_scope( 1955 variable_def.variable_name, import_scope=import_scope)) 1956 self._shape = tensor_shape.TensorShape(self._handle.op.get_attr("shape")) 1957 self._handle_name = self._handle.name 1958 self._unique_id = self._handle_name 1959 self._initializer_op = g.as_graph_element( 1960 ops.prepend_name_scope( 1961 variable_def.initializer_name, import_scope=import_scope)) 1962 # Check whether initial_value_name exists for backwards compatibility. 1963 if (hasattr(variable_def, "initial_value_name") and 1964 variable_def.initial_value_name): 1965 self._initial_value = g.as_graph_element( 1966 ops.prepend_name_scope( 1967 variable_def.initial_value_name, import_scope=import_scope)) 1968 else: 1969 self._initial_value = None 1970 synchronization, aggregation, trainable = ( 1971 variables.validate_synchronization_aggregation_trainable( 1972 variable_def.synchronization, variable_def.aggregation, 1973 variable_def.trainable, variable_def.variable_name)) 1974 self._synchronization = synchronization 1975 self._aggregation = aggregation 1976 self._trainable = trainable 1977 if variable_def.snapshot_name: 1978 snapshot = g.as_graph_element( 1979 ops.prepend_name_scope( 1980 variable_def.snapshot_name, import_scope=import_scope)) 1981 if snapshot.op.type != "ReadVariableOp": 1982 self._cached_value = snapshot 1983 else: 1984 self._cached_value = None 1985 while snapshot.op.type != "ReadVariableOp": 1986 snapshot = snapshot.op.inputs[0] 1987 self._graph_element = snapshot 1988 else: 1989 self._cached_value = None 1990 # Legacy case for protos without the snapshot name; assume it's the 1991 # following. 
1992 self._graph_element = g.get_tensor_by_name(self._handle.op.name + 1993 "/Read/ReadVariableOp:0") 1994 if variable_def.HasField("save_slice_info_def"): 1995 self._save_slice_info = variables.Variable.SaveSliceInfo( 1996 save_slice_info_def=variable_def.save_slice_info_def, 1997 import_scope=import_scope) 1998 else: 1999 self._save_slice_info = None 2000 self._caching_device = None 2001 self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype")) 2002 self._constraint = None 2003 self._validate_shape = validate_shape 2004 2005 2006class UninitializedVariable(BaseResourceVariable): 2007 """A variable with no initializer.""" 2008 2009 def __init__( # pylint: disable=super-init-not-called 2010 self, 2011 trainable=None, 2012 caching_device=None, 2013 name=None, 2014 shape=None, 2015 dtype=None, 2016 constraint=None, 2017 synchronization=None, 2018 aggregation=None, 2019 extra_handle_data=None, 2020 distribute_strategy=None, 2021 **unused_kwargs): 2022 """Creates the variable handle. 2023 2024 Args: 2025 trainable: If `True`, GradientTapes automatically watch uses of this 2026 Variable. 2027 caching_device: Optional device string or function describing where the 2028 Variable should be cached for reading. Defaults to the Variable's 2029 device. If not `None`, caches on another device. Typical use is to 2030 cache on the device where the Ops using the Variable reside, to 2031 deduplicate copying through `Switch` and other conditional statements. 2032 name: Optional name for the variable. Defaults to `'Variable'` and gets 2033 uniquified automatically. 2034 shape: The variable's shape. 2035 dtype: The variable's dtype. 2036 constraint: An optional projection function to be applied to the variable 2037 after being updated by an `Optimizer` (e.g. used to implement norm 2038 constraints or value constraints for layer weights). The function must 2039 take as input the unprojected Tensor representing the value of the 2040 variable and return the Tensor for the projected value (which must have 2041 the same shape). Constraints are not safe to use when doing asynchronous 2042 distributed training. 2043 synchronization: Indicates when a distributed a variable will be 2044 aggregated. Accepted values are constants defined in the class 2045 `tf.VariableSynchronization`. By default the synchronization is set to 2046 `AUTO` and the current `DistributionStrategy` chooses when to 2047 synchronize. 2048 aggregation: Indicates how a distributed variable will be aggregated. 2049 Accepted values are constants defined in the class 2050 `tf.VariableAggregation`. 2051 extra_handle_data: Optional, another resource handle or Tensor with handle 2052 data to merge with `shape` and `dtype`. 2053 distribute_strategy: The tf.distribute.Strategy this variable is being 2054 created inside of. 2055 """ 2056 with ops.init_scope(): 2057 # Here we are detecting eagerness within an init_scope, so this will only 2058 # be true when we are running in TF1 graph mode. 
2059 self._in_graph_mode = not context.executing_eagerly() 2060 with ops.name_scope(name, "Variable", skip_on_eager=False) as name: 2061 handle_name = ops.name_from_scope_name(name) 2062 if self._in_graph_mode: 2063 shared_name = handle_name 2064 unique_id = shared_name 2065 else: 2066 unique_id = "%s_%d" % (handle_name, ops.uid()) 2067 shared_name = None # Never shared 2068 handle = _variable_handle_from_shape_and_dtype( 2069 shape=shape, 2070 dtype=dtype, 2071 shared_name=shared_name, 2072 name=name, 2073 graph_mode=self._in_graph_mode, 2074 initial_value=extra_handle_data) 2075 handle._parent_trackable = weakref.ref(self) 2076 2077 if self._in_graph_mode: 2078 # We only need to add the read_variable_op in TF1. 2079 with ops.name_scope("Read"): 2080 # Manually assign reads to the handle's device to avoid log 2081 # messages. 2082 with ops.device(handle.device): 2083 value = gen_resource_variable_ops.read_variable_op(handle, dtype) 2084 _maybe_set_handle_data(dtype, handle, value) 2085 graph_element = value 2086 ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, self) 2087 # Do *not* add to TRAINABLE_VARIABLES here, even if self._trainable, 2088 # because retraining or frozen use of imported SavedModels is 2089 # controlled at higher levels of model building. 2090 else: 2091 graph_element = None 2092 super(UninitializedVariable, self).__init__( 2093 distribute_strategy=distribute_strategy, 2094 shape=shape, 2095 dtype=dtype, 2096 unique_id=unique_id, 2097 handle_name=handle_name, 2098 constraint=constraint, 2099 handle=handle, 2100 graph_element=graph_element, 2101 trainable=trainable, 2102 synchronization=synchronization, 2103 aggregation=aggregation, 2104 in_graph_mode=self._in_graph_mode) 2105 2106 2107_pywrap_utils.RegisterType("ResourceVariable", ResourceVariable) 2108math_ops._resource_variable_type = ResourceVariable # pylint: disable=protected-access 2109 2110 2111def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False): 2112 return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access 2113 2114 2115# Register a conversion function which reads the value of the variable, 2116# allowing instances of the class to be used as tensors. 2117ops.register_tensor_conversion_function(BaseResourceVariable, 2118 _dense_var_to_tensor) 2119 2120 2121class _UnreadVariable(BaseResourceVariable): 2122 """Represents a future for a read of a variable. 2123 2124 Pretends to be the tensor if anyone looks. 2125 """ 2126 2127 def __init__(self, handle, dtype, shape, in_graph_mode, parent_op, unique_id): 2128 if isinstance(handle, ops.EagerTensor): 2129 handle_name = "" 2130 else: 2131 handle_name = handle.name 2132 # Only create a graph_element if we're in session.run-land as only 2133 # session.run requires a preexisting tensor to evaluate. Otherwise we can 2134 # avoid accidentally reading the variable. 
2135 if context.executing_eagerly() or ops.inside_function(): 2136 graph_element = None 2137 else: 2138 with ops.control_dependencies([parent_op]): 2139 graph_element = gen_resource_variable_ops.read_variable_op( 2140 handle, dtype) 2141 _maybe_set_handle_data(dtype, handle, graph_element) 2142 super(_UnreadVariable, self).__init__( 2143 handle=handle, 2144 shape=shape, 2145 handle_name=handle_name, 2146 unique_id=unique_id, 2147 dtype=dtype, 2148 graph_element=graph_element) 2149 self._parent_op = parent_op 2150 2151 @property 2152 def name(self): 2153 if self._in_graph_mode: 2154 return self._parent_op.name 2155 else: 2156 return "UnreadVariable" 2157 2158 def value(self): 2159 return self._read_variable_op() 2160 2161 def read_value(self): 2162 return self._read_variable_op() 2163 2164 def _read_variable_op(self): 2165 with ops.control_dependencies([self._parent_op]): 2166 result = gen_resource_variable_ops.read_variable_op( 2167 self._handle, self._dtype) 2168 _maybe_set_handle_data(self._dtype, self._handle, result) 2169 return result 2170 2171 def assign_sub(self, delta, use_locking=None, name=None, read_value=True): 2172 with ops.control_dependencies([self._parent_op]): 2173 return super(_UnreadVariable, self).assign_sub(delta, use_locking, name, 2174 read_value) 2175 2176 def assign_add(self, delta, use_locking=None, name=None, read_value=True): 2177 with ops.control_dependencies([self._parent_op]): 2178 return super(_UnreadVariable, self).assign_add(delta, use_locking, name, 2179 read_value) 2180 2181 def assign(self, value, use_locking=None, name=None, read_value=True): 2182 with ops.control_dependencies([self._parent_op]): 2183 return super(_UnreadVariable, self).assign(value, use_locking, name, 2184 read_value) 2185 2186 def scatter_sub(self, sparse_delta, use_locking=False, name=None): 2187 with ops.control_dependencies([self._parent_op]): 2188 return super(_UnreadVariable, self).scatter_sub(sparse_delta, use_locking, 2189 name) 2190 2191 def scatter_add(self, sparse_delta, use_locking=False, name=None): 2192 with ops.control_dependencies([self._parent_op]): 2193 return super(_UnreadVariable, self).scatter_add(sparse_delta, use_locking, 2194 name) 2195 2196 def scatter_max(self, sparse_delta, use_locking=False, name=None): 2197 with ops.control_dependencies([self._parent_op]): 2198 return super(_UnreadVariable, self).scatter_max(sparse_delta, use_locking, 2199 name) 2200 2201 def scatter_min(self, sparse_delta, use_locking=False, name=None): 2202 with ops.control_dependencies([self._parent_op]): 2203 return super(_UnreadVariable, self).scatter_min(sparse_delta, use_locking, 2204 name) 2205 2206 def scatter_mul(self, sparse_delta, use_locking=False, name=None): 2207 with ops.control_dependencies([self._parent_op]): 2208 return super(_UnreadVariable, self).scatter_mul(sparse_delta, use_locking, 2209 name) 2210 2211 def scatter_div(self, sparse_delta, use_locking=False, name=None): 2212 with ops.control_dependencies([self._parent_op]): 2213 return super(_UnreadVariable, self).scatter_div(sparse_delta, use_locking, 2214 name) 2215 2216 def scatter_update(self, sparse_delta, use_locking=False, name=None): 2217 with ops.control_dependencies([self._parent_op]): 2218 return super(_UnreadVariable, 2219 self).scatter_update(sparse_delta, use_locking, name) 2220 2221 def batch_scatter_update(self, sparse_delta, use_locking=False, name=None): 2222 with ops.control_dependencies([self._parent_op]): 2223 return super(_UnreadVariable, 2224 self).batch_scatter_update(sparse_delta, use_locking, 
name) 2225 2226 def scatter_nd_sub(self, indices, updates, name=None): 2227 with ops.control_dependencies([self._parent_op]): 2228 return super(_UnreadVariable, self).scatter_nd_sub(indices, updates, name) 2229 2230 def scatter_nd_add(self, indices, updates, name=None): 2231 with ops.control_dependencies([self._parent_op]): 2232 return super(_UnreadVariable, self).scatter_nd_add(indices, updates, name) 2233 2234 def scatter_nd_update(self, indices, updates, name=None): 2235 with ops.control_dependencies([self._parent_op]): 2236 return super(_UnreadVariable, 2237 self).scatter_nd_update(indices, updates, name) 2238 2239 def scatter_nd_max(self, indices, updates, name=None): 2240 with ops.control_dependencies([self._parent_op]): 2241 return super(_UnreadVariable, self).scatter_nd_max(indices, updates, name) 2242 2243 def scatter_nd_min(self, indices, updates, name=None): 2244 with ops.control_dependencies([self._parent_op]): 2245 return super(_UnreadVariable, self).scatter_nd_min(indices, updates, name) 2246 2247 @property 2248 def op(self): 2249 """The op for this variable.""" 2250 return self._parent_op 2251 2252 2253@ops.RegisterGradient("ReadVariableOp") 2254def _ReadGrad(_, grad): 2255 """Gradient for read op.""" 2256 return grad 2257 2258 2259def variable_shape(handle, out_type=dtypes.int32): 2260 handle_data = get_eager_safe_handle_data(handle) 2261 if handle_data is None or not handle_data.is_set: 2262 return gen_resource_variable_ops.variable_shape(handle, out_type=out_type) 2263 shape_proto = handle_data.shape_and_type[0].shape 2264 if shape_proto.unknown_rank or any(x.size == -1 for x in shape_proto.dim): 2265 return gen_resource_variable_ops.variable_shape(handle, out_type=out_type) 2266 return constant_op.constant([x.size for x in shape_proto.dim], dtype=out_type) 2267 2268 2269@ops.RegisterGradient("ResourceGather") 2270def _GatherGrad(op, grad): 2271 """Gradient for gather op.""" 2272 # Build appropriately shaped IndexedSlices 2273 handle = op.inputs[0] 2274 indices = op.inputs[1] 2275 params_shape = variable_shape(handle) 2276 size = array_ops.expand_dims(array_ops.size(indices), 0) 2277 values_shape = array_ops.concat([size, params_shape[1:]], 0) 2278 values = array_ops.reshape(grad, values_shape) 2279 indices = array_ops.reshape(indices, size) 2280 return (indexed_slices.IndexedSlices(values, indices, params_shape), None) 2281 2282 2283def _to_proto_fn(v, export_scope=None): 2284 """Converts Variable and ResourceVariable to VariableDef for collections.""" 2285 return v.to_proto(export_scope=export_scope) 2286 2287 2288def _from_proto_fn(v, import_scope=None): 2289 """Creates Variable or ResourceVariable from VariableDef as needed.""" 2290 if v.is_resource: 2291 return ResourceVariable.from_proto(v, import_scope=import_scope) 2292 return variables.Variable.from_proto(v, import_scope=import_scope) 2293 2294 2295ops.register_proto_function( 2296 ops.GraphKeys.GLOBAL_VARIABLES, 2297 proto_type=variable_pb2.VariableDef, 2298 to_proto=_to_proto_fn, 2299 from_proto=_from_proto_fn) 2300ops.register_proto_function( 2301 ops.GraphKeys.TRAINABLE_VARIABLES, 2302 proto_type=variable_pb2.VariableDef, 2303 to_proto=_to_proto_fn, 2304 from_proto=_from_proto_fn) 2305ops.register_proto_function( 2306 ops.GraphKeys.MOVING_AVERAGE_VARIABLES, 2307 proto_type=variable_pb2.VariableDef, 2308 to_proto=_to_proto_fn, 2309 from_proto=_from_proto_fn) 2310ops.register_proto_function( 2311 ops.GraphKeys.LOCAL_VARIABLES, 2312 proto_type=variable_pb2.VariableDef, 2313 to_proto=_to_proto_fn, 2314 
from_proto=_from_proto_fn) 2315ops.register_proto_function( 2316 ops.GraphKeys.MODEL_VARIABLES, 2317 proto_type=variable_pb2.VariableDef, 2318 to_proto=_to_proto_fn, 2319 from_proto=_from_proto_fn) 2320ops.register_proto_function( 2321 ops.GraphKeys.GLOBAL_STEP, 2322 proto_type=variable_pb2.VariableDef, 2323 to_proto=_to_proto_fn, 2324 from_proto=_from_proto_fn) 2325ops.register_proto_function( 2326 ops.GraphKeys.METRIC_VARIABLES, 2327 proto_type=variable_pb2.VariableDef, 2328 to_proto=_to_proto_fn, 2329 from_proto=_from_proto_fn) 2330 2331 2332@tf_export("__internal__.ops.is_resource_variable", v1=[]) 2333def is_resource_variable(var): 2334 """"Returns True if `var` is to be considered a ResourceVariable.""" 2335 return isinstance(var, BaseResourceVariable) or hasattr( 2336 var, "_should_act_as_resource_variable") 2337 2338 2339def copy_to_graph_uninitialized(var): 2340 """Copies an existing variable to a new graph, with no initializer.""" 2341 # Like ResourceVariable.__deepcopy__, but does not set an initializer on the 2342 # new variable. 2343 # pylint: disable=protected-access 2344 new_variable = UninitializedVariable( 2345 trainable=var.trainable, 2346 constraint=var._constraint, 2347 shape=var.shape, 2348 dtype=var.dtype, 2349 name=var._shared_name, 2350 synchronization=var.synchronization, 2351 aggregation=var.aggregation, 2352 extra_handle_data=var.handle) 2353 new_variable._maybe_initialize_trackable() 2354 # pylint: enable=protected-access 2355 return new_variable 2356 2357 2358ops.NotDifferentiable("Assert") 2359ops.NotDifferentiable("VarIsInitializedOp") 2360ops.NotDifferentiable("VariableShape") 2361 2362 2363class VariableSpec(tensor_spec.DenseSpec): 2364 """Describes a tf.Variable.""" 2365 2366 __slots__ = ["trainable"] 2367 2368 value_type = property(lambda self: BaseResourceVariable) 2369 2370 def __init__(self, shape, dtype=dtypes.float32, trainable=True): 2371 super(VariableSpec, self).__init__(shape, dtype=dtype) 2372 self.trainable = trainable 2373 2374 def is_compatible_with(self, spec_or_value): 2375 return (isinstance(spec_or_value, (type(self), self.value_type)) and 2376 self.shape.is_compatible_with(spec_or_value.shape) and 2377 self.dtype == spec_or_value.dtype and 2378 self.trainable == spec_or_value.trainable) 2379 2380 @classmethod 2381 def from_value(cls, value): 2382 return cls(value.shape, dtype=value.dtype, trainable=value.trainable) 2383 2384 def _to_components(self, value): 2385 return value.handle 2386 2387 def _from_components(self, components): 2388 return BaseResourceVariable( 2389 trainable=self.trainable, 2390 shape=self.shape, 2391 dtype=self.dtype, 2392 handle=components) 2393 2394 @property 2395 def _component_specs(self): 2396 return tensor_spec.TensorSpec(self.shape, dtypes.resource) 2397 2398 def _from_compatible_tensor_list(self, tensor_list): 2399 assert len(tensor_list) == 1 2400 return tensor_list[0] 2401 2402 def _serialize(self): 2403 return self.shape, self.dtype, self.trainable 2404 2405 def __tf_tracing_type__(self, signature_context): 2406 return signature_context.make_reference_type(self, id(self)) 2407 2408 def __repr__(self): 2409 return (f"{type(self).__name__}(shape={self.shape}, dtype={self.dtype}, " 2410 f"trainable={self.trainable})") 2411 2412 def __hash__(self): 2413 return hash((self.shape, self.dtype, self.trainable)) 2414 2415 def __eq__(self, other): 2416 return (type(self) is type(other) and self.shape == other.shape and 2417 self.dtype == other.dtype and self.trainable == other.trainable) 2418 2419 
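# Illustrative sketch (an editor's addition, not part of the library): how the
# `VariableSpec` defined above can describe a `tf.Variable` and how its
# compatibility check behaves. The variable and spec names are hypothetical,
# and the example is kept as comments so nothing runs at import time.
#
#   v = variables.Variable([[0.0, 0.0], [0.0, 0.0]])   # float32, shape (2, 2)
#   spec = VariableSpec.from_value(v)
#   # -> VariableSpec(shape=(2, 2), dtype=tf.float32, trainable=True)
#   spec.is_compatible_with(v)        # True: shape, dtype and trainable match
#   other = VariableSpec(shape=(2, 2), dtype=dtypes.float32, trainable=False)
#   spec.is_compatible_with(other)    # False: the `trainable` flag differs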
2420_pywrap_utils.RegisterType("VariableSpec", VariableSpec) 2421 2422 2423def write_object_proto_for_resource_variable(resource_variable, 2424 proto, 2425 options, 2426 enforce_naming=True): 2427 """Writes additional information of the variable into the SavedObject proto. 2428 2429 This allows users to define a `hook` to provide extra information about the 2430 variable to the SavedObject. 2431 2432 For example, the DistributedVariable class would fill in components in the 2433 distributed context. 2434 2435 Args: 2436 resource_variable: A `ResourceVariable` or `DistributedValue` that has the 2437 information to be saved into the proto. 2438 proto: `SavedObject` proto to update. 2439 options: A `SaveOption` instance that configures save behavior. 2440 enforce_naming: A bool determining whether to check that names end in the 2441 expected string ':0'. 2442 """ 2443 proto.variable.SetInParent() 2444 if enforce_naming and not resource_variable.name.endswith(":0"): 2445 raise ValueError(f"Cowardly refusing to save variable " 2446 f"{resource_variable.name} because of an " 2447 f"unexpected suffix in its name (expected ':0'), " 2448 f"which won't be restored.") 2449 proto.variable.name = meta_graph._op_name(resource_variable.name) # pylint: disable=protected-access 2450 proto.variable.trainable = resource_variable.trainable 2451 proto.variable.dtype = resource_variable.dtype.as_datatype_enum 2452 proto.variable.synchronization = resource_variable.synchronization.value 2453 proto.variable.aggregation = resource_variable.aggregation.value 2454 proto.variable.shape.CopyFrom(resource_variable.shape.as_proto()) 2455 if options.experimental_variable_policy._save_variable_devices( # pylint: disable=protected-access 2456 ): 2457 if hasattr(resource_variable, "device"): 2458 proto.variable.device = resource_variable.device 2459
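# Illustrative sketch (an editor's addition, not part of the library): roughly
# how a SavedModel writer could use `write_object_proto_for_resource_variable`
# above. The `saved_object_graph_pb2` import path and the `save_options` object
# are assumptions for illustration only, so the example stays in comments.
#
#   from tensorflow.core.protobuf import saved_object_graph_pb2
#   saved_obj = saved_object_graph_pb2.SavedObject()
#   var = ResourceVariable(initial_value=[1.0, 2.0], name="weights")
#   write_object_proto_for_resource_variable(var, saved_obj, save_options)
#   # saved_obj.variable now records the op name ("weights"), dtype, shape,
#   # trainable, synchronization and aggregation, which a loader can use to
#   # rebuild an equivalent variable.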