xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/state_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Variables.
17
18See the [Variables](https://www.tensorflow.org/guide/variables) guide.
19"""
20
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import tensor_shape
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import gen_math_ops
25from tensorflow.python.ops import gen_resource_variable_ops
26from tensorflow.python.ops import gen_state_ops
27# go/tf-wildcard-import
28# pylint: disable=wildcard-import
29from tensorflow.python.ops.gen_state_ops import *
30# pylint: enable=wildcard-import
31from tensorflow.python.util import deprecation
32from tensorflow.python.util.deprecation import deprecated
33from tensorflow.python.util.tf_export import tf_export
34
35
36# pylint: disable=protected-access,g-doc-return-or-yield,g-doc-args
def variable_op(shape, dtype, name="Variable", set_shape=True, container="",
                shared_name=""):
  """Deprecated. Use variable_op_v2 instead.

  Creates a variable op via `gen_state_ops.variable`.

  Args:
    shape: The shape of the tensor managed by this variable.
    dtype: The underlying type of the tensor values.
    name: Optional name to use for the variable op.
    set_shape: If True, `shape` is passed to the op and is also set as the
      static shape of the returned tensor. If False, the op is created with an
      unknown shape and no static shape is set on the result.
    container: An optional string passed through to the op. Defaults to "".
    shared_name: An optional string passed through to the op. Defaults to "".

  Returns:
    The tensor produced by the variable op.
  """
  if not set_shape:
    shape = tensor_shape.unknown_shape()
  ret = gen_state_ops.variable(shape=shape, dtype=dtype, name=name,
                               container=container, shared_name=shared_name)
  # TODO(mrry): Move this to where it is used, so we can get rid of this op
  #   wrapper?
  if set_shape:
    ret.set_shape(shape)
  return ret
49
50
def variable_op_v2(shape, dtype, name="Variable", container="", shared_name=""):
  """Creates a variable op via `gen_state_ops.variable_v2`.

  See also variables.Variable.

  Args:
    shape: The shape of the tensor held by the variable.
    dtype: The element type of the tensor held by the variable.
    name: Optional name for the variable op.
    container: Optional string, defaults to "". When non-empty the variable is
      placed in this container; otherwise a default container is used.
    shared_name: Optional string, defaults to "". When non-empty the variable
      is named in the given bucket with this shared_name; otherwise the node
      name is used instead.

  Returns:
    A variable tensor.
  """
  op_kwargs = dict(
      shape=shape,
      dtype=dtype,
      name=name,
      container=container,
      shared_name=shared_name,
  )
  return gen_state_ops.variable_v2(**op_kwargs)
76
77
def init_variable(v, init, name="init"):
  """Builds the op that assigns an initial value to variable `v`.

  If `init` is callable, it is invoked as
  `init(v.get_shape().as_list(), v.dtype.base_dtype)` and its (tensor-converted)
  result is assigned to `v`; this path requires `v` to have a fully defined
  static shape. Otherwise `init` itself is converted to a tensor and assigned.

  Args:
    v: Variable to initialize.
    init: A Tensor, or an object convertible to a Tensor (e.g. an ndarray), to
      assign to `v`; or an initializer callable invoked as
      `init(shape, dtype)` that returns the value `v` should be set to.
    name: Optional name for the op. It is opened inside a name scope derived
      from `v.op.name`.

  Returns:
    The operation that initializes v.

  Raises:
    AssertionError: (when assertions are enabled) if `init` is callable and
      the shape of `v` is not fully defined.
  """
  with ops.name_scope(None, v.op.name + "/", [v, init]):
    with ops.name_scope(name) as scope:
      with ops.colocate_with(v):
        if not callable(init):
          initial_value = ops.convert_to_tensor(init, name="init")
        else:
          assert v.get_shape().is_fully_defined(), "Variable shape unknown."
          # TODO(mrry): Convert to v.shape when the property and accessor are
          # reconciled (and all initializers support tf.TensorShape objects).
          shape_list = v.get_shape().as_list()
          initial_value = ops.convert_to_tensor(
              init(shape_list, v.dtype.base_dtype), name="value")
        return gen_state_ops.assign(v, initial_value, name=scope)
111
112
def is_variable_initialized(ref, name=None):
  """Reports whether a variable has been initialized.

  Produces a boolean scalar telling whether the variable holds a value yet.

  Args:
    ref: A mutable `Tensor`.
      Should be from a `Variable` node. May be uninitialized.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `bool`.
  """
  if not ref.dtype._is_ref_dtype:
    # Non-ref (resource) variables answer this themselves.
    return ref.is_initialized(name=name)
  return gen_state_ops.is_variable_initialized(ref=ref, name=name)
130
131
@tf_export(v1=["assign_sub"])
def assign_sub(ref, value, use_locking=None, name=None):
  """Update `ref` by subtracting `value` from it.

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.
  Unlike `tf.math.subtract`, this op does not broadcast. `ref` and `value`
  must have the same shape.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`,
      `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`. Should be
      from a `Variable` node.
    value: A `Tensor`. Must have the same shape and dtype as `ref`. The value to
      be subtracted from the variable.
    use_locking: An optional `bool`. Defaults to `False`. If True, the
      subtraction will be protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention. Only forwarded when `ref`
      has a ref dtype; it is not passed to resource variables.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`.  Returned as a convenience for operations that want
    to use the new value after the variable has been updated.

  @compatibility(TF2)
  `tf.compat.v1.assign_sub` is mostly compatible with eager
  execution and `tf.function`.

  To switch to the native TF2 style, one could use method 'assign_sub' of
  `tf.Variable`:

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `ref`                 | `self`          | In `assign_sub()` method   |
  | `value`               | `value`         | In `assign_sub()` method   |
  | `use_locking`         | `use_locking`   | In `assign_sub()` method   |
  | `name`                | `name`          | In `assign_sub()` method   |
  | -                     | `read_value`    | Set to True to replicate   |
  :                       :                 : behavior (True is default) :


  #### Before & After Usage Example

  Before:

  >>> with tf.Graph().as_default():
  ...   with tf.compat.v1.Session() as sess:
  ...     a = tf.compat.v1.Variable(1, dtype=tf.int64)
  ...     sess.run(a.initializer)
  ...     update_op = tf.compat.v1.assign_sub(a, 1)
  ...     res_a = sess.run(update_op)
  ...     res_a
  0

  After:

  >>> b = tf.Variable(1, dtype=tf.int64)
  >>> res_b = b.assign_sub(1)
  >>> res_b.numpy()
  0

  @end_compatibility
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign_sub(
        ref, value, use_locking=use_locking, name=name)
  # Resource variables: forward `name` (as `assign` does) instead of silently
  # dropping the caller's op name.
  return ref.assign_sub(value, name=name)
202
203
@tf_export(v1=["assign_add"])
def assign_add(ref, value, use_locking=None, name=None):
  """Update `ref` by adding `value` to it.

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.
  Unlike `tf.math.add`, this op does not broadcast. `ref` and `value` must have
  the same shape.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`,
      `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`. Should be
      from a `Variable` node.
    value: A `Tensor`. Must have the same shape and dtype as `ref`. The value to
      be added to the variable.
    use_locking: An optional `bool`. Defaults to `False`. If True, the addition
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention. Only forwarded when `ref` has a ref dtype; it
      is not passed to resource variables.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`.  Returned as a convenience for operations that want
    to use the new value after the variable has been updated.

  @compatibility(TF2)
  `tf.compat.v1.assign_add` is mostly compatible with eager
  execution and `tf.function`.

  To switch to the native TF2 style, one could use method 'assign_add' of
  `tf.Variable`:

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `ref`                 | `self`          | In `assign_add()` method   |
  | `value`               | `value`         | In `assign_add()` method   |
  | `use_locking`         | `use_locking`   | In `assign_add()` method   |
  | `name`                | `name`          | In `assign_add()` method   |
  | -                     | `read_value`    | Set to True to replicate   |
  :                       :                 : behavior (True is default) :


  #### Before & After Usage Example

  Before:

  >>> with tf.Graph().as_default():
  ...   with tf.compat.v1.Session() as sess:
  ...     a = tf.compat.v1.Variable(0, dtype=tf.int64)
  ...     sess.run(a.initializer)
  ...     update_op = tf.compat.v1.assign_add(a, 1)
  ...     res_a = sess.run(update_op)
  ...     res_a
  1

  After:

  >>> b = tf.Variable(0, dtype=tf.int64)
  >>> res_b = b.assign_add(1)
  >>> res_b.numpy()
  1

  @end_compatibility
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign_add(
        ref, value, use_locking=use_locking, name=name)
  # Resource variables: forward `name` (as `assign` does) instead of silently
  # dropping the caller's op name.
  return ref.assign_add(value, name=name)
274
275
@tf_export(v1=["assign"])
def assign(ref, value, validate_shape=None, use_locking=None, name=None):
  """Update `ref` by assigning `value` to it.

  This operation outputs a Tensor that holds the new value of `ref` after
  the value has been assigned. This makes it easier to chain operations that
  need to use the updated value.

  Args:
    ref: A mutable `Tensor`. Should be from a `Variable` node. May be
      uninitialized.
    value: A `Tensor`. Must have the same shape and dtype as `ref`. The value to
      be assigned to the variable.
    validate_shape: An optional `bool`. Defaults to `True`. If true, the
      operation will validate that the shape of 'value' matches the shape of the
      Tensor being assigned to.  If false, 'ref' will take on the shape of
      'value'. Only forwarded when `ref` has a ref dtype.
    use_locking: An optional `bool`. Defaults to `True`. If True, the assignment
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention. Only forwarded when `ref` has a ref dtype.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` that will hold the new value of `ref` after
      the assignment has completed.

  @compatibility(TF2)
  `tf.compat.v1.assign` is mostly compatible with eager
  execution and `tf.function`. However, argument 'validate_shape' will be
  ignored. To avoid shape validation, set 'shape' to tf.TensorShape(None) when
  constructing the variable:

  >>> a = tf.Variable([1], shape=tf.TensorShape(None))
  >>> _ = tf.compat.v1.assign(a, [2, 3])
  >>> a.numpy()
  array([2, 3], dtype=int32)

  To switch to the native TF2 style, one could use method 'assign' of
  `tf.Variable`:

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `ref`                 | `self`          | In `assign()` method       |
  | `value`               | `value`         | In `assign()` method       |
  | `validate_shape`      | Not supported   | Specify `shape` in the     |
  :                       :                 : constructor to replicate   :
  :                       :                 : behavior                   :
  | `use_locking`         | `use_locking`   | In `assign()` method       |
  | `name`                | `name`          | In `assign()` method       |
  | -                     | `read_value`    | Set to True to replicate   |
  :                       :                 : behavior (True is default) :


  #### Before & After Usage Example

  Before:

  >>> with tf.Graph().as_default():
  ...   with tf.compat.v1.Session() as sess:
  ...     a = tf.compat.v1.Variable(0, dtype=tf.int64)
  ...     sess.run(a.initializer)
  ...     update_op = tf.compat.v1.assign(a, 2)
  ...     res_a = sess.run(update_op)
  ...     res_a
  2

  After:

  >>> b = tf.Variable(0, dtype=tf.int64)
  >>> res_b = b.assign(2)
  >>> res_b.numpy()
  2

  @end_compatibility
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign(
        ref, value, use_locking=use_locking, name=name,
        validate_shape=validate_shape)
  # Resource variables: only `name` is forwarded; `validate_shape` and
  # `use_locking` are not passed to `ref.assign`.
  return ref.assign(value, name=name)
356
357
@tf_export(v1=["count_up_to"])
@deprecated(None, "Prefer Dataset.range instead.")
def count_up_to(ref, limit, name=None):
  r"""Increments 'ref' until it reaches 'limit'.

  Args:
    ref: A Variable. Must be one of the following types: `int32`, `int64`.
      Should be from a scalar `Variable` node.
    limit: An `int`.
      If incrementing ref would bring it above limit, instead generates an
      'OutOfRange' error.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `ref`.
    A copy of the input before increment. If nothing else modifies the
    input, the values produced will all be distinct.
  """
  if not ref.dtype._is_ref_dtype:
    # Resource variables are incremented through their handle.
    return gen_state_ops.resource_count_up_to(
        ref.handle, limit, T=ref.dtype, name=name)
  return gen_state_ops.count_up_to(ref, limit=limit, name=name)
380
381
@tf_export(v1=["scatter_update"])
def scatter_update(ref, indices, updates, use_locking=True, name=None):
  # pylint: disable=line-too-long
  r"""Applies sparse updates to a variable reference.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] = updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] = updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.

  If a value in `ref` is updated more than once, because there are
  duplicate entries in `indices`, the order in which the updates happen
  for that value is undefined.

  Requires `updates.shape = indices.shape + ref.shape[1:]`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterUpdate.png" alt>
  </div>

  Args:
    ref: A `Variable`.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to store in `ref`.
    use_locking: An optional `bool`. Defaults to `True`.
      If True, the assignment will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
      Not forwarded for resource variables.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`.  Returned as a convenience for operations that want
    to use the updated values after the update is done.
  """
  if not ref.dtype._is_ref_dtype:
    # Resource path: `updates` is converted to the variable dtype and
    # `use_locking` is not forwarded.
    handle = ref.handle
    update_tensor = ops.convert_to_tensor(updates, ref.dtype)
    scatter_op = gen_resource_variable_ops.resource_scatter_update(
        handle, indices, update_tensor, name=name)
    return ref._lazy_read(scatter_op)  # pylint: disable=protected-access
  return gen_state_ops.scatter_update(ref, indices, updates,
                                      use_locking=use_locking, name=name)
434
435
@tf_export(v1=["scatter_nd_update"])
def scatter_nd_update(ref, indices, updates, use_locking=True, name=None):
  r"""Applies sparse `updates` to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
  ```

  For example, say we want to update 4 scattered elements in a rank-1 tensor
  with 8 elements. In Python, that update would look like this:

  ```python
      ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
      indices = tf.constant([[4], [3], [1], [7]])
      updates = tf.constant([9, 10, 11, 12])
      update = tf.compat.v1.scatter_nd_update(ref, indices, updates)
      with tf.compat.v1.Session() as sess:
        sess.run(ref.initializer)
        print(sess.run(update))
  ```

  The resulting update to ref would look like this:

      [1, 11, 3, 10, 9, 6, 7, 12]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A Variable.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to store in ref.
    use_locking: An optional `bool`. Defaults to `True`.
      If True, the assignment will be protected by a lock; otherwise the
      behavior is undefined, but may exhibit less contention. Not forwarded
      for resource variables.
    name: A name for the operation (optional).

  Returns:
    The value of the variable after the update.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_nd_update(
        ref, indices, updates, use_locking=use_locking, name=name)
  # Resource variables: `updates` is converted to the variable dtype and
  # `use_locking` is not forwarded.
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_update(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
496
497
@tf_export(v1=["scatter_add"])
def scatter_add(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Adds sparse updates to the variable referenced by `ref`.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] += updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] += updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] += updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.
  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions add.

  Requires `updates.shape = indices.shape + ref.shape[1:]`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
  </div>

  Args:
    ref: A `Variable`.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of values to add to `ref`.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the addition will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
      Not forwarded for resource variables.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`.  Returned as a convenience for operations that want
    to use the updated values after the update is done.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_add(ref, indices, updates,
                                     use_locking=use_locking, name=name)
  # Resource variables: `updates` is converted to the variable dtype and
  # `use_locking` is not forwarded.
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_add(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
548
549
@tf_export(v1=["scatter_nd_add"])
def scatter_nd_add(ref, indices, updates, use_locking=False, name=None):
  r"""Applies sparse addition to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]
  ```

  For example, say we want to add 4 scattered elements to a rank-1 tensor
  with 8 elements. In Python, that addition would look like this:

  ```python
  ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
  indices = tf.constant([[4], [3], [1], [7]])
  updates = tf.constant([9, 10, 11, 12])
  add = tf.compat.v1.scatter_nd_add(ref, indices, updates)
  with tf.compat.v1.Session() as sess:
    sess.run(ref.initializer)
    print(sess.run(add))
  ```

  The resulting update to ref would look like this:

      [1, 13, 3, 14, 14, 6, 7, 20]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. A mutable Tensor. Should be from a Variable node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to add to ref.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the addition will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
      Not forwarded for resource variables.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_nd_add(
        ref, indices, updates, use_locking=use_locking, name=name)
  # Resource variables: `updates` is converted to the variable dtype and
  # `use_locking` is not forwarded.
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_add(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
611
612
@tf_export(v1=["scatter_sub"])
def scatter_sub(ref, indices, updates, use_locking=False, name=None):
  r"""Subtracts sparse updates from a variable reference.

  ```python
      # Scalar indices
      ref[indices, ...] -= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] -= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their (negated) contributions add.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or
  `updates.shape = []`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%"
       src="https://www.tensorflow.org/images/ScatterSub.png" alt>
  </div>

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. Should be from a `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to subtract from `ref`.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the subtraction will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
      Not forwarded for resource variables.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_sub(ref, indices, updates,
                                     use_locking=use_locking, name=name)
  # Resource variables: `updates` is converted to the variable dtype and
  # `use_locking` is not forwarded.
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_sub(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
665
666
@tf_export(v1=["scatter_nd_sub"])
def scatter_nd_sub(ref, indices, updates, use_locking=False, name=None):
  r"""Applies sparse subtraction to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]
  ```

  For example, say we want to subtract 4 scattered elements from a rank-1 tensor
  with 8 elements. In Python, that update would look like this:

  ```python
  ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
  indices = tf.constant([[4], [3], [1], [7]])
  updates = tf.constant([9, 10, 11, 12])
  op = tf.compat.v1.scatter_nd_sub(ref, indices, updates)
  with tf.compat.v1.Session() as sess:
    sess.run(ref.initializer)
    print(sess.run(op))
  ```

  The resulting update to ref would look like this:

      [1, -9, 3, -6, -4, 6, 7, -4]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. A mutable Tensor. Should be from a Variable node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to subtract from ref.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the subtraction will be protected by a lock; otherwise the
      behavior is undefined, but may exhibit less contention. Not forwarded
      for resource variables.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_nd_sub(
        ref, indices, updates, use_locking=use_locking, name=name)
  # Resource variables: `updates` is converted to the variable dtype and
  # `use_locking` is not forwarded.
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_sub(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
729
730
@tf_export(v1=["scatter_mul"])
def scatter_mul(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Multiplies sparse updates into a variable reference.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] *= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] *= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions multiply.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
  []`.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. Should be from a `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A
      tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of values
      to multiply `ref` by.
    use_locking: An optional `bool`. Defaults to `False`. If True, the operation
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention. Not forwarded for resource variables.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if not ref.dtype._is_ref_dtype:
    # Resource path: `updates` is converted to the variable dtype and
    # `use_locking` is not forwarded.
    handle = ref.handle
    factor = ops.convert_to_tensor(updates, ref.dtype)
    mul_op = gen_resource_variable_ops.resource_scatter_mul(
        handle, indices, factor, name=name)
    return ref._lazy_read(mul_op)  # pylint: disable=protected-access
  return gen_state_ops.scatter_mul(ref, indices, updates,
                                   use_locking=use_locking, name=name)
781
782
@tf_export(v1=["scatter_div"])
def scatter_div(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Divides a variable reference by sparse updates.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] /= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] /= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions divide.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
  []`.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. Should be from a `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A
      tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of values
      that `ref` is divided by.
    use_locking: An optional `bool`. Defaults to `False`. If True, the operation
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention. Not forwarded for resource variables.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if not ref.dtype._is_ref_dtype:
    # Resource path: `updates` is converted to the variable dtype and
    # `use_locking` is not forwarded.
    handle = ref.handle
    divisor = ops.convert_to_tensor(updates, ref.dtype)
    div_op = gen_resource_variable_ops.resource_scatter_div(
        handle, indices, divisor, name=name)
    return ref._lazy_read(div_op)  # pylint: disable=protected-access
  return gen_state_ops.scatter_div(ref, indices, updates,
                                   use_locking=use_locking, name=name)
833
834
835@tf_export(v1=["scatter_max"])
836def scatter_max(ref, indices, updates, use_locking=False, name=None):
837  # pylint: disable=line-too-long
838  r"""Reduces sparse updates into a variable reference using the `max` operation.
839
840  This operation computes
841
842      # Scalar indices
843      ref[indices, ...] = max(ref[indices, ...], updates[...])
844
845      # Vector indices (for each i)
846      ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...])
847
848      # High rank indices (for each i, ..., j)
849      ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...],
850      updates[i, ..., j, ...])
851
852  This operation outputs `ref` after the update is done.
853  This makes it easier to chain operations that need to use the reset value.
854
855  Duplicate entries are handled correctly: if multiple `indices` reference
856  the same location, their contributions combine.
857
858  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
859  []`.
860
861  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
862  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png"
863  alt>
864  </div>
865
866  Args:
867    ref: A mutable `Tensor`. Must be one of the following types: `half`,
868      `bfloat16`, `float32`, `float64`, `int32`, `int64`. Should be from a
869      `Variable` node.
870    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A
871      tensor of indices into the first dimension of `ref`.
872    updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated
873      values to reduce into `ref`.
874    use_locking: An optional `bool`. Defaults to `False`. If True, the update
875      will be protected by a lock; otherwise the behavior is undefined, but may
876      exhibit less contention.
877    name: A name for the operation (optional).
878
879  Returns:
880    A mutable `Tensor`. Has the same type as `ref`.
881  """
882  if ref.dtype._is_ref_dtype:
883    return gen_state_ops.scatter_max(ref, indices, updates,
884                                     use_locking=use_locking, name=name)
885  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_max(  # pylint: disable=protected-access
886      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
887      name=name))
888
889
890@tf_export(v1=["scatter_min"])
891def scatter_min(ref, indices, updates, use_locking=False, name=None):
892  # pylint: disable=line-too-long
893  r"""Reduces sparse updates into a variable reference using the `min` operation.
894
895  This operation computes
896
897      # Scalar indices
898      ref[indices, ...] = min(ref[indices, ...], updates[...])
899
900      # Vector indices (for each i)
901      ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...])
902
903      # High rank indices (for each i, ..., j)
904      ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...],
905      updates[i, ..., j, ...])
906
907  This operation outputs `ref` after the update is done.
908  This makes it easier to chain operations that need to use the reset value.
909
910  Duplicate entries are handled correctly: if multiple `indices` reference
911  the same location, their contributions combine.
912
913  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
914  []`.
915
916  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
917  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png"
918  alt>
919  </div>
920
921  Args:
922    ref: A mutable `Tensor`. Must be one of the following types: `half`,
923      `bfloat16`, `float32`, `float64`, `int32`, `int64`. Should be from a
924      `Variable` node.
925    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A
926      tensor of indices into the first dimension of `ref`.
927    updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated
928      values to reduce into `ref`.
929    use_locking: An optional `bool`. Defaults to `False`. If True, the update
930      will be protected by a lock; otherwise the behavior is undefined, but may
931      exhibit less contention.
932    name: A name for the operation (optional).
933
934  Returns:
935    A mutable `Tensor`. Has the same type as `ref`.
936  """
937  if ref.dtype._is_ref_dtype:
938    return gen_state_ops.scatter_min(ref, indices, updates,
939                                     use_locking=use_locking, name=name)
940  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_min(  # pylint: disable=protected-access
941      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
942      name=name))
943
944
945@tf_export(v1=["batch_scatter_update"])
946@deprecation.deprecated(
947    "2018-11-29", "Use the batch_scatter_update method of Variable instead.")
948def batch_scatter_update(ref, indices, updates, use_locking=True, name=None):
949  """Generalization of `tf.compat.v1.scatter_update` to axis different than 0.
950
951  Analogous to `batch_gather`. This assumes that `ref`, `indices` and `updates`
952  have a series of leading dimensions that are the same for all of them, and the
953  updates are performed on the last dimension of indices. In other words, the
954  dimensions should be the following:
955
956  `num_prefix_dims = indices.ndims - 1`
957  `batch_dim = num_prefix_dims + 1`
958  `updates.shape = indices.shape + var.shape[batch_dim:]`
959
960  where
961
962  `updates.shape[:num_prefix_dims]`
963  `== indices.shape[:num_prefix_dims]`
964  `== var.shape[:num_prefix_dims]`
965
966  And the operation performed can be expressed as:
967
968  `var[i_1, ..., i_n, indices[i_1, ..., i_n, j]] = updates[i_1, ..., i_n, j]`
969
970  When indices is a 1D tensor, this operation is equivalent to
971  `tf.compat.v1.scatter_update`.
972
973  To avoid this operation there would be 2 alternatives:
974  1) Reshaping the variable by merging the first `ndims` dimensions. However,
975     this is not possible because `tf.reshape` returns a Tensor, which we
976     cannot use `tf.compat.v1.scatter_update` on.
977  2) Looping over the first `ndims` of the variable and using
978     `tf.compat.v1.scatter_update` on the subtensors that result of slicing the
979     first
980     dimension. This is a valid option for `ndims = 1`, but less efficient than
981     this implementation.
982
983  See also `tf.compat.v1.scatter_update` and `tf.compat.v1.scatter_nd_update`.
984
985  Args:
986    ref: `Variable` to scatter onto.
987    indices: Tensor containing indices as described above.
988    updates: Tensor of updates to apply to `ref`.
989    use_locking: Boolean indicating whether to lock the writing operation.
990    name: Optional scope name string.
991
992  Returns:
993    Ref to `variable` after it has been modified.
994
995  Raises:
996    ValueError: If the initial `ndims` of `ref`, `indices`, and `updates` are
997        not the same.
998  """
999  with ops.name_scope(name):
1000    indices = ops.convert_to_tensor(indices, name="indices")
1001    indices_shape = array_ops.shape(indices)
1002    indices_dimensions = indices.get_shape().ndims
1003
1004    if indices_dimensions is None:
1005      raise ValueError("batch_gather does not allow indices with unknown "
1006                       "shape.")
1007
1008    nd_indices = array_ops.expand_dims(indices, axis=-1)
1009    nd_indices_list = []
1010
1011    # Scatter ND requires indices to have an additional dimension, in which the
1012    # coordinates of the updated things are specified. For this to be adapted to
1013    # the scatter_update with several leading dimensions, we simply make use of
1014    # a tf.range for all the leading dimensions followed by concat of all the
1015    # coordinates we created with the original indices.
1016
1017    # For example if indices.shape = [2, 3, 4], we should generate the following
1018    # indices for tf.compat.v1.scatter_nd_update:
1019    # nd_indices[:, :, 0] = [[0, 0, 0], [1, 1, 1]]
1020    # nd_indices[:, :, 1] = [[0, 1, 2], [0, 1, 2]]
1021    # nd_indices[:, :, 2] = indices
1022    for dimension in range(indices_dimensions - 1):
1023      # In this loop we generate the following for the example (one for each
1024      # iteration).
1025      # nd_indices[:, :, 0] = [[0, 0, 0], [1, 1, 1]]
1026      # nd_indices[:, :, 1] = [[0, 1, 2], [0, 1, 2]]
1027      # This is done at every iteration with a tf.range over the size of the
1028      # i-th dimension and using broadcasting over the desired shape.
1029      dimension_size = indices_shape[dimension]
1030      shape_to_broadcast = [1] * (indices_dimensions + 1)
1031      shape_to_broadcast[dimension] = dimension_size
1032      dimension_range = array_ops.reshape(
1033          gen_math_ops._range(0, dimension_size, 1), shape_to_broadcast)
1034      if dimension_range.dtype.base_dtype != nd_indices.dtype:
1035        dimension_range = gen_math_ops.cast(dimension_range, nd_indices.dtype)
1036      nd_indices_list.append(
1037          dimension_range * array_ops.ones_like(nd_indices))
1038    # Add the original indices at the end, as described above, and concat.
1039    nd_indices_list.append(nd_indices)
1040    final_indices = array_ops.concat(nd_indices_list, axis=-1)
1041    return scatter_nd_update(
1042        ref, final_indices, updates, use_locking=use_locking)
1043