1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Variables. 17 18See the [Variables](https://www.tensorflow.org/guide/variables) guide. 19""" 20 21from tensorflow.python.framework import ops 22from tensorflow.python.framework import tensor_shape 23from tensorflow.python.ops import array_ops 24from tensorflow.python.ops import gen_math_ops 25from tensorflow.python.ops import gen_resource_variable_ops 26from tensorflow.python.ops import gen_state_ops 27# go/tf-wildcard-import 28# pylint: disable=wildcard-import 29from tensorflow.python.ops.gen_state_ops import * 30# pylint: enable=wildcard-import 31from tensorflow.python.util import deprecation 32from tensorflow.python.util.deprecation import deprecated 33from tensorflow.python.util.tf_export import tf_export 34 35 36# pylint: disable=protected-access,g-doc-return-or-yield,g-doc-args 37def variable_op(shape, dtype, name="Variable", set_shape=True, container="", 38 shared_name=""): 39 """Deprecated. Used variable_op_v2 instead.""" 40 if not set_shape: 41 shape = tensor_shape.unknown_shape() 42 ret = gen_state_ops.variable(shape=shape, dtype=dtype, name=name, 43 container=container, shared_name=shared_name) 44 # TODO(mrry): Move this to where it is used, so we can get rid of this op 45 # wrapper? 46 if set_shape: 47 ret.set_shape(shape) 48 return ret 49 50 51def variable_op_v2(shape, dtype, name="Variable", container="", shared_name=""): 52 """Create a variable Operation. 53 54 See also variables.Variable. 55 56 Args: 57 shape: The shape of the tensor managed by this variable 58 dtype: The underlying type of the tensor values. 59 name: optional name to use for the variable op. 60 container: An optional string. Defaults to "". 61 If non-empty, this variable is placed in the given container. 62 Otherwise, a default container is used. 63 shared_name: An optional string. Defaults to "". 64 If non-empty, this variable is named in the given bucket 65 with this shared_name. Otherwise, the node name is used instead. 66 67 Returns: 68 A variable tensor. 69 """ 70 return gen_state_ops.variable_v2( 71 shape=shape, 72 dtype=dtype, 73 name=name, 74 container=container, 75 shared_name=shared_name) 76 77 78def init_variable(v, init, name="init"): 79 """Initializes variable with "init". 80 81 This op does the following: 82 if init is a Tensor, v = init 83 if callable(init): v = init(VariableShape(v), v.dtype) 84 85 Args: 86 v: Variable to initialize 87 init: Tensor to assign to v, 88 Or an object convertible to Tensor e.g. nparray, 89 Or an Initializer that generates a tensor given the shape and type of v. 90 An "Initializer" is a callable that returns a tensor that "v" should be 91 set to. It will be called as init(shape, dtype). 92 name: Optional name for the op. 93 94 Returns: 95 The operation that initializes v. 96 """ 97 with ops.name_scope(None, v.op.name + "/", [v, init]): 98 with ops.name_scope(name) as scope: 99 with ops.colocate_with(v): 100 if callable(init): 101 assert v.get_shape().is_fully_defined(), "Variable shape unknown." 102 # TODO(mrry): Convert to v.shape when the property and 103 # accessor are reconciled (and all initializers support 104 # tf.TensorShape objects). 105 value = init(v.get_shape().as_list(), v.dtype.base_dtype) 106 value = ops.convert_to_tensor(value, name="value") 107 return gen_state_ops.assign(v, value, name=scope) 108 else: 109 init = ops.convert_to_tensor(init, name="init") 110 return gen_state_ops.assign(v, init, name=scope) 111 112 113def is_variable_initialized(ref, name=None): 114 """Checks whether a tensor has been initialized. 115 116 Outputs boolean scalar indicating whether the tensor has been initialized. 117 118 Args: 119 ref: A mutable `Tensor`. 120 Should be from a `Variable` node. May be uninitialized. 121 name: A name for the operation (optional). 122 123 Returns: 124 A `Tensor` of type `bool`. 125 """ 126 if ref.dtype._is_ref_dtype: 127 return gen_state_ops.is_variable_initialized(ref=ref, name=name) 128 # Handle resource variables. 129 return ref.is_initialized(name=name) 130 131 132@tf_export(v1=["assign_sub"]) 133def assign_sub(ref, value, use_locking=None, name=None): 134 """Update `ref` by subtracting `value` from it. 135 136 This operation outputs `ref` after the update is done. 137 This makes it easier to chain operations that need to use the reset value. 138 Unlike `tf.math.subtract`, this op does not broadcast. `ref` and `value` 139 must have the same shape. 140 141 Args: 142 ref: A mutable `Tensor`. Must be one of the following types: `float32`, 143 `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, 144 `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`. Should be 145 from a `Variable` node. 146 value: A `Tensor`. Must have the same shape and dtype as `ref`. The value to 147 be subtracted to the variable. 148 use_locking: An optional `bool`. Defaults to `False`. If True, the 149 subtraction will be protected by a lock; otherwise the behavior is 150 undefined, but may exhibit less contention. 151 name: A name for the operation (optional). 152 153 Returns: 154 Same as `ref`. Returned as a convenience for operations that want 155 to use the new value after the variable has been updated. 156 157 @compatibility(TF2) 158 `tf.compat.v1.assign_sub` is mostly compatible with eager 159 execution and `tf.function`. 160 161 To switch to the native TF2 style, one could use method 'assign_sub' of 162 `tf.Variable`: 163 164 #### How to Map Arguments 165 166 | TF1 Arg Name | TF2 Arg Name | Note | 167 | :-------------------- | :-------------- | :------------------------- | 168 | `ref` | `self` | In `assign_sub()` method | 169 | `value` | `value` | In `assign_sub()` method | 170 | `use_locking` | `use_locking` | In `assign_sub()` method | 171 | `name` | `name` | In `assign_sub()` method | 172 | - | `read_value` | Set to True to replicate | 173 : : : behavior (True is default) : 174 175 176 #### Before & After Usage Example 177 178 Before: 179 180 >>> with tf.Graph().as_default(): 181 ... with tf.compat.v1.Session() as sess: 182 ... a = tf.compat.v1.Variable(1, dtype=tf.int64) 183 ... sess.run(a.initializer) 184 ... update_op = tf.compat.v1.assign_sub(a, 1) 185 ... res_a = sess.run(update_op) 186 ... res_a 187 0 188 189 After: 190 191 >>> b = tf.Variable(1, dtype=tf.int64) 192 >>> res_b = b.assign_sub(1) 193 >>> res_b.numpy() 194 0 195 196 @end_compatibility 197 """ 198 if ref.dtype._is_ref_dtype: 199 return gen_state_ops.assign_sub( 200 ref, value, use_locking=use_locking, name=name) 201 return ref.assign_sub(value) 202 203 204@tf_export(v1=["assign_add"]) 205def assign_add(ref, value, use_locking=None, name=None): 206 """Update `ref` by adding `value` to it. 207 208 This operation outputs `ref` after the update is done. 209 This makes it easier to chain operations that need to use the reset value. 210 Unlike `tf.math.add`, this op does not broadcast. `ref` and `value` must have 211 the same shape. 212 213 Args: 214 ref: A mutable `Tensor`. Must be one of the following types: `float32`, 215 `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, 216 `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`. Should be 217 from a `Variable` node. 218 value: A `Tensor`. Must have the same shape and dtype as `ref`. The value to 219 be added to the variable. 220 use_locking: An optional `bool`. Defaults to `False`. If True, the addition 221 will be protected by a lock; otherwise the behavior is undefined, but may 222 exhibit less contention. 223 name: A name for the operation (optional). 224 225 Returns: 226 Same as `ref`. Returned as a convenience for operations that want 227 to use the new value after the variable has been updated. 228 229 @compatibility(TF2) 230 `tf.compat.v1.assign_add` is mostly compatible with eager 231 execution and `tf.function`. 232 233 To switch to the native TF2 style, one could use method 'assign_add' of 234 `tf.Variable`: 235 236 #### How to Map Arguments 237 238 | TF1 Arg Name | TF2 Arg Name | Note | 239 | :-------------------- | :-------------- | :------------------------- | 240 | `ref` | `self` | In `assign_add()` method | 241 | `value` | `value` | In `assign_add()` method | 242 | `use_locking` | `use_locking` | In `assign_add()` method | 243 | `name` | `name` | In `assign_add()` method | 244 | - | `read_value` | Set to True to replicate | 245 : : : behavior (True is default) : 246 247 248 #### Before & After Usage Example 249 250 Before: 251 252 >>> with tf.Graph().as_default(): 253 ... with tf.compat.v1.Session() as sess: 254 ... a = tf.compat.v1.Variable(0, dtype=tf.int64) 255 ... sess.run(a.initializer) 256 ... update_op = tf.compat.v1.assign_add(a, 1) 257 ... res_a = sess.run(update_op) 258 ... res_a 259 1 260 261 After: 262 263 >>> b = tf.Variable(0, dtype=tf.int64) 264 >>> res_b = b.assign_add(1) 265 >>> res_b.numpy() 266 1 267 268 @end_compatibility 269 """ 270 if ref.dtype._is_ref_dtype: 271 return gen_state_ops.assign_add( 272 ref, value, use_locking=use_locking, name=name) 273 return ref.assign_add(value) 274 275 276@tf_export(v1=["assign"]) 277def assign(ref, value, validate_shape=None, use_locking=None, name=None): 278 """Update `ref` by assigning `value` to it. 279 280 This operation outputs a Tensor that holds the new value of `ref` after 281 the value has been assigned. This makes it easier to chain operations that 282 need to use the reset value. 283 284 Args: 285 ref: A mutable `Tensor`. Should be from a `Variable` node. May be 286 uninitialized. 287 value: A `Tensor`. Must have the same shape and dtype as `ref`. The value to 288 be assigned to the variable. 289 validate_shape: An optional `bool`. Defaults to `True`. If true, the 290 operation will validate that the shape of 'value' matches the shape of the 291 Tensor being assigned to. If false, 'ref' will take on the shape of 292 'value'. 293 use_locking: An optional `bool`. Defaults to `True`. If True, the assignment 294 will be protected by a lock; otherwise the behavior is undefined, but may 295 exhibit less contention. 296 name: A name for the operation (optional). 297 298 Returns: 299 A `Tensor` that will hold the new value of `ref` after 300 the assignment has completed. 301 302 @compatibility(TF2) 303 `tf.compat.v1.assign` is mostly compatible with eager 304 execution and `tf.function`. However, argument 'validate_shape' will be 305 ignored. To avoid shape validation, set 'shape' to tf.TensorShape(None) when 306 constructing the variable: 307 308 >>> import tensorflow as tf 309 >>> a = tf.Variable([1], shape=tf.TensorShape(None)) 310 >>> tf.compat.v1.assign(a, [2,3]) 311 312 To switch to the native TF2 style, one could use method 'assign' of 313 `tf.Variable`: 314 315 #### How to Map Arguments 316 317 | TF1 Arg Name | TF2 Arg Name | Note | 318 | :-------------------- | :-------------- | :------------------------- | 319 | `ref` | `self` | In `assign()` method | 320 | `value` | `value` | In `assign()` method | 321 | `validate_shape` | Not supported | Specify `shape` in the | 322 : : : constructor to replicate : 323 : : : behavior : 324 | `use_locking` | `use_locking` | In `assign()` method | 325 | `name` | `name` | In `assign()` method | 326 | - | `read_value` | Set to True to replicate | 327 : : : behavior (True is default) : 328 @end_compatibility 329 330 331 #### Before & After Usage Example 332 333 Before: 334 335 >>> with tf.Graph().as_default(): 336 ... with tf.compat.v1.Session() as sess: 337 ... a = tf.compat.v1.Variable(0, dtype=tf.int64) 338 ... sess.run(a.initializer) 339 ... update_op = tf.compat.v1.assign(a, 2) 340 ... res_a = sess.run(update_op) 341 ... res_a 342 2 343 344 After: 345 346 >>> b = tf.Variable(0, dtype=tf.int64) 347 >>> res_b = b.assign(2) 348 >>> res_b.numpy() 349 2 350 """ 351 if ref.dtype._is_ref_dtype: 352 return gen_state_ops.assign( 353 ref, value, use_locking=use_locking, name=name, 354 validate_shape=validate_shape) 355 return ref.assign(value, name=name) 356 357 358@tf_export(v1=["count_up_to"]) 359@deprecated(None, "Prefer Dataset.range instead.") 360def count_up_to(ref, limit, name=None): 361 r"""Increments 'ref' until it reaches 'limit'. 362 363 Args: 364 ref: A Variable. Must be one of the following types: `int32`, `int64`. 365 Should be from a scalar `Variable` node. 366 limit: An `int`. 367 If incrementing ref would bring it above limit, instead generates an 368 'OutOfRange' error. 369 name: A name for the operation (optional). 370 371 Returns: 372 A `Tensor`. Has the same type as `ref`. 373 A copy of the input before increment. If nothing else modifies the 374 input, the values produced will all be distinct. 375 """ 376 if ref.dtype._is_ref_dtype: 377 return gen_state_ops.count_up_to(ref, limit=limit, name=name) 378 return gen_state_ops.resource_count_up_to( 379 ref.handle, limit, T=ref.dtype, name=name) 380 381 382@tf_export(v1=["scatter_update"]) 383def scatter_update(ref, indices, updates, use_locking=True, name=None): 384 # pylint: disable=line-too-long 385 r"""Applies sparse updates to a variable reference. 386 387 This operation computes 388 389 ```python 390 # Scalar indices 391 ref[indices, ...] = updates[...] 392 393 # Vector indices (for each i) 394 ref[indices[i], ...] = updates[i, ...] 395 396 # High rank indices (for each i, ..., j) 397 ref[indices[i, ..., j], ...] = updates[i, ..., j, ...] 398 ``` 399 400 This operation outputs `ref` after the update is done. 401 This makes it easier to chain operations that need to use the reset value. 402 403 If values in `ref` is to be updated more than once, because there are 404 duplicate entries in `indices`, the order at which the updates happen 405 for each value is undefined. 406 407 Requires `updates.shape = indices.shape + ref.shape[1:]`. 408 409 <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> 410 <img style="width:100%" src="https://www.tensorflow.org/images/ScatterUpdate.png" alt> 411 </div> 412 413 Args: 414 ref: A `Variable`. 415 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. 416 A tensor of indices into the first dimension of `ref`. 417 updates: A `Tensor`. Must have the same type as `ref`. 418 A tensor of updated values to store in `ref`. 419 use_locking: An optional `bool`. Defaults to `True`. 420 If True, the assignment will be protected by a lock; 421 otherwise the behavior is undefined, but may exhibit less contention. 422 name: A name for the operation (optional). 423 424 Returns: 425 Same as `ref`. Returned as a convenience for operations that want 426 to use the updated values after the update is done. 427 """ 428 if ref.dtype._is_ref_dtype: 429 return gen_state_ops.scatter_update(ref, indices, updates, 430 use_locking=use_locking, name=name) 431 return ref._lazy_read(gen_resource_variable_ops.resource_scatter_update( # pylint: disable=protected-access 432 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), 433 name=name)) 434 435 436@tf_export(v1=["scatter_nd_update"]) 437def scatter_nd_update(ref, indices, updates, use_locking=True, name=None): 438 r"""Applies sparse `updates` to individual values or slices in a Variable. 439 440 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. 441 442 `indices` must be integer tensor, containing indices into `ref`. 443 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 444 445 The innermost dimension of `indices` (with length `K`) corresponds to 446 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 447 dimension of `ref`. 448 449 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 450 451 ``` 452 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. 453 ``` 454 455 For example, say we want to update 4 scattered elements to a rank-1 tensor to 456 8 elements. In Python, that update would look like this: 457 458 ```python 459 ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) 460 indices = tf.constant([[4], [3], [1] ,[7]]) 461 updates = tf.constant([9, 10, 11, 12]) 462 update = tf.compat.v1.scatter_nd_update(ref, indices, updates) 463 with tf.compat.v1.Session() as sess: 464 print sess.run(update) 465 ``` 466 467 The resulting update to ref would look like this: 468 469 [1, 11, 3, 10, 9, 6, 7, 12] 470 471 See `tf.scatter_nd` for more details about how to make updates to 472 slices. 473 474 Args: 475 ref: A Variable. 476 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. 477 A tensor of indices into ref. 478 updates: A `Tensor`. Must have the same type as `ref`. 479 A Tensor. Must have the same type as ref. A tensor of updated 480 values to add to ref. 481 use_locking: An optional `bool`. Defaults to `True`. 482 An optional bool. Defaults to True. If True, the assignment will 483 be protected by a lock; otherwise the behavior is undefined, 484 but may exhibit less contention. 485 name: A name for the operation (optional). 486 487 Returns: 488 The value of the variable after the update. 489 """ 490 if ref.dtype._is_ref_dtype: 491 return gen_state_ops.scatter_nd_update( 492 ref, indices, updates, use_locking, name) 493 return ref._lazy_read(gen_state_ops.resource_scatter_nd_update( # pylint: disable=protected-access 494 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), 495 name=name)) 496 497 498@tf_export(v1=["scatter_add"]) 499def scatter_add(ref, indices, updates, use_locking=False, name=None): 500 # pylint: disable=line-too-long 501 r"""Adds sparse updates to the variable referenced by `resource`. 502 503 This operation computes 504 505 ```python 506 # Scalar indices 507 ref[indices, ...] += updates[...] 508 509 # Vector indices (for each i) 510 ref[indices[i], ...] += updates[i, ...] 511 512 # High rank indices (for each i, ..., j) 513 ref[indices[i, ..., j], ...] += updates[i, ..., j, ...] 514 ``` 515 516 This operation outputs `ref` after the update is done. 517 This makes it easier to chain operations that need to use the updated value. 518 Duplicate entries are handled correctly: if multiple `indices` reference 519 the same location, their contributions add. 520 521 Requires `updates.shape = indices.shape + ref.shape[1:]`. 522 523 <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> 524 <img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt> 525 </div> 526 527 Args: 528 ref: A `Variable`. 529 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. 530 A tensor of indices into the first dimension of `ref`. 531 updates: A `Tensor`. Must have the same type as `ref`. 532 A tensor of updated values to store in `ref`. 533 use_locking: An optional `bool`. Defaults to `False`. 534 If True, the assignment will be protected by a lock; 535 otherwise the behavior is undefined, but may exhibit less contention. 536 name: A name for the operation (optional). 537 538 Returns: 539 Same as `ref`. Returned as a convenience for operations that want 540 to use the updated values after the update is done. 541 """ 542 if ref.dtype._is_ref_dtype: 543 return gen_state_ops.scatter_add(ref, indices, updates, 544 use_locking=use_locking, name=name) 545 return ref._lazy_read(gen_resource_variable_ops.resource_scatter_add( # pylint: disable=protected-access 546 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), 547 name=name)) 548 549 550@tf_export(v1=["scatter_nd_add"]) 551def scatter_nd_add(ref, indices, updates, use_locking=False, name=None): 552 r"""Applies sparse addition to individual values or slices in a Variable. 553 554 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. 555 556 `indices` must be integer tensor, containing indices into `ref`. 557 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 558 559 The innermost dimension of `indices` (with length `K`) corresponds to 560 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 561 dimension of `ref`. 562 563 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 564 565 ``` 566 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]] 567 ``` 568 569 For example, say we want to add 4 scattered elements to a rank-1 tensor to 570 8 elements. In Python, that addition would look like this: 571 572 ```python 573 ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) 574 indices = tf.constant([[4], [3], [1], [7]]) 575 updates = tf.constant([9, 10, 11, 12]) 576 add = tf.compat.v1.scatter_nd_add(ref, indices, updates) 577 with tf.compat.v1.Session() as sess: 578 print sess.run(add) 579 ``` 580 581 The resulting update to ref would look like this: 582 583 [1, 13, 3, 14, 14, 6, 7, 20] 584 585 See `tf.scatter_nd` for more details about how to make updates to 586 slices. 587 588 Args: 589 ref: A mutable `Tensor`. Must be one of the following types: `float32`, 590 `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, 591 `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, 592 `uint32`, `uint64`. A mutable Tensor. Should be from a Variable node. 593 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. 594 A tensor of indices into ref. 595 updates: A `Tensor`. Must have the same type as `ref`. 596 A tensor of updated values to add to ref. 597 use_locking: An optional `bool`. Defaults to `False`. 598 If True, the assignment will be protected by a lock; 599 otherwise the behavior is undefined, but may exhibit less contention. 600 name: A name for the operation (optional). 601 602 Returns: 603 A mutable `Tensor`. Has the same type as `ref`. 604 """ 605 if ref.dtype._is_ref_dtype: 606 return gen_state_ops.scatter_nd_add( 607 ref, indices, updates, use_locking, name) 608 return ref._lazy_read(gen_state_ops.resource_scatter_nd_add( # pylint: disable=protected-access 609 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), 610 name=name)) 611 612 613@tf_export(v1=["scatter_sub"]) 614def scatter_sub(ref, indices, updates, use_locking=False, name=None): 615 r"""Subtracts sparse updates to a variable reference. 616 617 ```python 618 # Scalar indices 619 ref[indices, ...] -= updates[...] 620 621 # Vector indices (for each i) 622 ref[indices[i], ...] -= updates[i, ...] 623 624 # High rank indices (for each i, ..., j) 625 ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...] 626 ``` 627 628 This operation outputs `ref` after the update is done. 629 This makes it easier to chain operations that need to use the reset value. 630 631 Duplicate entries are handled correctly: if multiple `indices` reference 632 the same location, their (negated) contributions add. 633 634 Requires `updates.shape = indices.shape + ref.shape[1:]` or 635 `updates.shape = []`. 636 637 <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> 638 <img style="width:100%" 639 src="https://www.tensorflow.org/images/ScatterSub.png" alt> 640 </div> 641 642 Args: 643 ref: A mutable `Tensor`. Must be one of the following types: `float32`, 644 `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, 645 `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, 646 `uint32`, `uint64`. Should be from a `Variable` node. 647 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. 648 A tensor of indices into the first dimension of `ref`. 649 updates: A `Tensor`. Must have the same type as `ref`. 650 A tensor of updated values to subtract from `ref`. 651 use_locking: An optional `bool`. Defaults to `False`. 652 If True, the subtraction will be protected by a lock; 653 otherwise the behavior is undefined, but may exhibit less contention. 654 name: A name for the operation (optional). 655 656 Returns: 657 A mutable `Tensor`. Has the same type as `ref`. 658 """ 659 if ref.dtype._is_ref_dtype: 660 return gen_state_ops.scatter_sub(ref, indices, updates, 661 use_locking=use_locking, name=name) 662 return ref._lazy_read(gen_resource_variable_ops.resource_scatter_sub( # pylint: disable=protected-access 663 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), 664 name=name)) 665 666 667@tf_export(v1=["scatter_nd_sub"]) 668def scatter_nd_sub(ref, indices, updates, use_locking=False, name=None): 669 r"""Applies sparse subtraction to individual values or slices in a Variable. 670 671 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. 672 673 `indices` must be integer tensor, containing indices into `ref`. 674 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 675 676 The innermost dimension of `indices` (with length `K`) corresponds to 677 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 678 dimension of `ref`. 679 680 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 681 682 ``` 683 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]] 684 ``` 685 686 For example, say we want to subtract 4 scattered elements from a rank-1 tensor 687 with 8 elements. In Python, that update would look like this: 688 689 ```python 690 ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) 691 indices = tf.constant([[4], [3], [1] ,[7]]) 692 updates = tf.constant([9, 10, 11, 12]) 693 op = tf.compat.v1.scatter_nd_sub(ref, indices, updates) 694 with tf.compat.v1.Session() as sess: 695 print sess.run(op) 696 ``` 697 698 The resulting update to ref would look like this: 699 700 [1, -9, 3, -6, -6, 6, 7, -4] 701 702 See `tf.scatter_nd` for more details about how to make updates to 703 slices. 704 705 Args: 706 ref: A mutable `Tensor`. Must be one of the following types: `float32`, 707 `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, 708 `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, 709 `uint32`, `uint64`. A mutable Tensor. Should be from a Variable node. 710 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. 711 A tensor of indices into ref. 712 updates: A `Tensor`. Must have the same type as `ref`. 713 A tensor of updated values to add to ref. 714 use_locking: An optional `bool`. Defaults to `False`. 715 An optional bool. Defaults to True. If True, the assignment will 716 be protected by a lock; otherwise the behavior is undefined, 717 but may exhibit less contention. 718 name: A name for the operation (optional). 719 720 Returns: 721 A mutable `Tensor`. Has the same type as `ref`. 722 """ 723 if ref.dtype._is_ref_dtype: 724 return gen_state_ops.scatter_nd_sub( 725 ref, indices, updates, use_locking, name) 726 return ref._lazy_read(gen_state_ops.resource_scatter_nd_sub( # pylint: disable=protected-access 727 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), 728 name=name)) 729 730 731@tf_export(v1=["scatter_mul"]) 732def scatter_mul(ref, indices, updates, use_locking=False, name=None): 733 # pylint: disable=line-too-long 734 r"""Multiplies sparse updates into a variable reference. 735 736 This operation computes 737 738 ```python 739 # Scalar indices 740 ref[indices, ...] *= updates[...] 741 742 # Vector indices (for each i) 743 ref[indices[i], ...] *= updates[i, ...] 744 745 # High rank indices (for each i, ..., j) 746 ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...] 747 ``` 748 749 This operation outputs `ref` after the update is done. 750 This makes it easier to chain operations that need to use the reset value. 751 752 Duplicate entries are handled correctly: if multiple `indices` reference 753 the same location, their contributions multiply. 754 755 Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = 756 []`. 757 758 Args: 759 ref: A mutable `Tensor`. Must be one of the following types: `float32`, 760 `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, 761 `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, 762 `uint32`, `uint64`. Should be from a `Variable` node. 763 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A 764 tensor of indices into the first dimension of `ref`. 765 updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated 766 values to multiply to `ref`. 767 use_locking: An optional `bool`. Defaults to `False`. If True, the operation 768 will be protected by a lock; otherwise the behavior is undefined, but may 769 exhibit less contention. 770 name: A name for the operation (optional). 771 772 Returns: 773 A mutable `Tensor`. Has the same type as `ref`. 774 """ 775 if ref.dtype._is_ref_dtype: 776 return gen_state_ops.scatter_mul(ref, indices, updates, 777 use_locking=use_locking, name=name) 778 return ref._lazy_read(gen_resource_variable_ops.resource_scatter_mul( # pylint: disable=protected-access 779 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), 780 name=name)) 781 782 783@tf_export(v1=["scatter_div"]) 784def scatter_div(ref, indices, updates, use_locking=False, name=None): 785 # pylint: disable=line-too-long 786 r"""Divides a variable reference by sparse updates. 787 788 This operation computes 789 790 ```python 791 # Scalar indices 792 ref[indices, ...] /= updates[...] 793 794 # Vector indices (for each i) 795 ref[indices[i], ...] /= updates[i, ...] 796 797 # High rank indices (for each i, ..., j) 798 ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...] 799 ``` 800 801 This operation outputs `ref` after the update is done. 802 This makes it easier to chain operations that need to use the reset value. 803 804 Duplicate entries are handled correctly: if multiple `indices` reference 805 the same location, their contributions divide. 806 807 Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = 808 []`. 809 810 Args: 811 ref: A mutable `Tensor`. Must be one of the following types: `float32`, 812 `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, 813 `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, 814 `uint32`, `uint64`. Should be from a `Variable` node. 815 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A 816 tensor of indices into the first dimension of `ref`. 817 updates: A `Tensor`. Must have the same type as `ref`. A tensor of values 818 that `ref` is divided by. 819 use_locking: An optional `bool`. Defaults to `False`. If True, the operation 820 will be protected by a lock; otherwise the behavior is undefined, but may 821 exhibit less contention. 822 name: A name for the operation (optional). 823 824 Returns: 825 A mutable `Tensor`. Has the same type as `ref`. 826 """ 827 if ref.dtype._is_ref_dtype: 828 return gen_state_ops.scatter_div(ref, indices, updates, 829 use_locking=use_locking, name=name) 830 return ref._lazy_read(gen_resource_variable_ops.resource_scatter_div( # pylint: disable=protected-access 831 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), 832 name=name)) 833 834 835@tf_export(v1=["scatter_max"]) 836def scatter_max(ref, indices, updates, use_locking=False, name=None): 837 # pylint: disable=line-too-long 838 r"""Reduces sparse updates into a variable reference using the `max` operation. 839 840 This operation computes 841 842 # Scalar indices 843 ref[indices, ...] = max(ref[indices, ...], updates[...]) 844 845 # Vector indices (for each i) 846 ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...]) 847 848 # High rank indices (for each i, ..., j) 849 ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...], 850 updates[i, ..., j, ...]) 851 852 This operation outputs `ref` after the update is done. 853 This makes it easier to chain operations that need to use the reset value. 854 855 Duplicate entries are handled correctly: if multiple `indices` reference 856 the same location, their contributions combine. 857 858 Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = 859 []`. 860 861 <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> 862 <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" 863 alt> 864 </div> 865 866 Args: 867 ref: A mutable `Tensor`. Must be one of the following types: `half`, 868 `bfloat16`, `float32`, `float64`, `int32`, `int64`. Should be from a 869 `Variable` node. 870 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A 871 tensor of indices into the first dimension of `ref`. 872 updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated 873 values to reduce into `ref`. 874 use_locking: An optional `bool`. Defaults to `False`. If True, the update 875 will be protected by a lock; otherwise the behavior is undefined, but may 876 exhibit less contention. 877 name: A name for the operation (optional). 878 879 Returns: 880 A mutable `Tensor`. Has the same type as `ref`. 881 """ 882 if ref.dtype._is_ref_dtype: 883 return gen_state_ops.scatter_max(ref, indices, updates, 884 use_locking=use_locking, name=name) 885 return ref._lazy_read(gen_resource_variable_ops.resource_scatter_max( # pylint: disable=protected-access 886 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), 887 name=name)) 888 889 890@tf_export(v1=["scatter_min"]) 891def scatter_min(ref, indices, updates, use_locking=False, name=None): 892 # pylint: disable=line-too-long 893 r"""Reduces sparse updates into a variable reference using the `min` operation. 894 895 This operation computes 896 897 # Scalar indices 898 ref[indices, ...] = min(ref[indices, ...], updates[...]) 899 900 # Vector indices (for each i) 901 ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...]) 902 903 # High rank indices (for each i, ..., j) 904 ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], 905 updates[i, ..., j, ...]) 906 907 This operation outputs `ref` after the update is done. 908 This makes it easier to chain operations that need to use the reset value. 909 910 Duplicate entries are handled correctly: if multiple `indices` reference 911 the same location, their contributions combine. 912 913 Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = 914 []`. 915 916 <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> 917 <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" 918 alt> 919 </div> 920 921 Args: 922 ref: A mutable `Tensor`. Must be one of the following types: `half`, 923 `bfloat16`, `float32`, `float64`, `int32`, `int64`. Should be from a 924 `Variable` node. 925 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A 926 tensor of indices into the first dimension of `ref`. 927 updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated 928 values to reduce into `ref`. 929 use_locking: An optional `bool`. Defaults to `False`. If True, the update 930 will be protected by a lock; otherwise the behavior is undefined, but may 931 exhibit less contention. 932 name: A name for the operation (optional). 933 934 Returns: 935 A mutable `Tensor`. Has the same type as `ref`. 936 """ 937 if ref.dtype._is_ref_dtype: 938 return gen_state_ops.scatter_min(ref, indices, updates, 939 use_locking=use_locking, name=name) 940 return ref._lazy_read(gen_resource_variable_ops.resource_scatter_min( # pylint: disable=protected-access 941 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), 942 name=name)) 943 944 945@tf_export(v1=["batch_scatter_update"]) 946@deprecation.deprecated( 947 "2018-11-29", "Use the batch_scatter_update method of Variable instead.") 948def batch_scatter_update(ref, indices, updates, use_locking=True, name=None): 949 """Generalization of `tf.compat.v1.scatter_update` to axis different than 0. 950 951 Analogous to `batch_gather`. This assumes that `ref`, `indices` and `updates` 952 have a series of leading dimensions that are the same for all of them, and the 953 updates are performed on the last dimension of indices. In other words, the 954 dimensions should be the following: 955 956 `num_prefix_dims = indices.ndims - 1` 957 `batch_dim = num_prefix_dims + 1` 958 `updates.shape = indices.shape + var.shape[batch_dim:]` 959 960 where 961 962 `updates.shape[:num_prefix_dims]` 963 `== indices.shape[:num_prefix_dims]` 964 `== var.shape[:num_prefix_dims]` 965 966 And the operation performed can be expressed as: 967 968 `var[i_1, ..., i_n, indices[i_1, ..., i_n, j]] = updates[i_1, ..., i_n, j]` 969 970 When indices is a 1D tensor, this operation is equivalent to 971 `tf.compat.v1.scatter_update`. 972 973 To avoid this operation there would be 2 alternatives: 974 1) Reshaping the variable by merging the first `ndims` dimensions. However, 975 this is not possible because `tf.reshape` returns a Tensor, which we 976 cannot use `tf.compat.v1.scatter_update` on. 977 2) Looping over the first `ndims` of the variable and using 978 `tf.compat.v1.scatter_update` on the subtensors that result of slicing the 979 first 980 dimension. This is a valid option for `ndims = 1`, but less efficient than 981 this implementation. 982 983 See also `tf.compat.v1.scatter_update` and `tf.compat.v1.scatter_nd_update`. 984 985 Args: 986 ref: `Variable` to scatter onto. 987 indices: Tensor containing indices as described above. 988 updates: Tensor of updates to apply to `ref`. 989 use_locking: Boolean indicating whether to lock the writing operation. 990 name: Optional scope name string. 991 992 Returns: 993 Ref to `variable` after it has been modified. 994 995 Raises: 996 ValueError: If the initial `ndims` of `ref`, `indices`, and `updates` are 997 not the same. 998 """ 999 with ops.name_scope(name): 1000 indices = ops.convert_to_tensor(indices, name="indices") 1001 indices_shape = array_ops.shape(indices) 1002 indices_dimensions = indices.get_shape().ndims 1003 1004 if indices_dimensions is None: 1005 raise ValueError("batch_gather does not allow indices with unknown " 1006 "shape.") 1007 1008 nd_indices = array_ops.expand_dims(indices, axis=-1) 1009 nd_indices_list = [] 1010 1011 # Scatter ND requires indices to have an additional dimension, in which the 1012 # coordinates of the updated things are specified. For this to be adapted to 1013 # the scatter_update with several leading dimensions, we simply make use of 1014 # a tf.range for all the leading dimensions followed by concat of all the 1015 # coordinates we created with the original indices. 1016 1017 # For example if indices.shape = [2, 3, 4], we should generate the following 1018 # indices for tf.compat.v1.scatter_nd_update: 1019 # nd_indices[:, :, 0] = [[0, 0, 0], [1, 1, 1]] 1020 # nd_indices[:, :, 1] = [[0, 1, 2], [0, 1, 2]] 1021 # nd_indices[:, :, 2] = indices 1022 for dimension in range(indices_dimensions - 1): 1023 # In this loop we generate the following for the example (one for each 1024 # iteration). 1025 # nd_indices[:, :, 0] = [[0, 0, 0], [1, 1, 1]] 1026 # nd_indices[:, :, 1] = [[0, 1, 2], [0, 1, 2]] 1027 # This is done at every iteration with a tf.range over the size of the 1028 # i-th dimension and using broadcasting over the desired shape. 1029 dimension_size = indices_shape[dimension] 1030 shape_to_broadcast = [1] * (indices_dimensions + 1) 1031 shape_to_broadcast[dimension] = dimension_size 1032 dimension_range = array_ops.reshape( 1033 gen_math_ops._range(0, dimension_size, 1), shape_to_broadcast) 1034 if dimension_range.dtype.base_dtype != nd_indices.dtype: 1035 dimension_range = gen_math_ops.cast(dimension_range, nd_indices.dtype) 1036 nd_indices_list.append( 1037 dimension_range * array_ops.ones_like(nd_indices)) 1038 # Add the original indices at the end, as described above, and concat. 1039 nd_indices_list.append(nd_indices) 1040 final_indices = array_ops.concat(nd_indices_list, axis=-1) 1041 return scatter_nd_update( 1042 ref, final_indices, updates, use_locking=use_locking) 1043