1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.ops.resource_variable_ops."""
16import copy
17import gc
18import os
19import pickle
20import re
21
22from absl.testing import parameterized
23import numpy as np
24
25from tensorflow.core.framework import full_type_pb2
26from tensorflow.core.framework import tensor_pb2
27from tensorflow.python.compat import compat as forward_compat
28from tensorflow.python.eager import backprop
29from tensorflow.python.eager import context
30from tensorflow.python.eager import def_function
31from tensorflow.python.framework import constant_op
32from tensorflow.python.framework import cpp_shape_inference_pb2
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import errors
35from tensorflow.python.framework import indexed_slices
36from tensorflow.python.framework import memory_checker
37from tensorflow.python.framework import ops
38from tensorflow.python.framework import tensor_shape
39from tensorflow.python.framework import tensor_util
40from tensorflow.python.framework import test_ops
41from tensorflow.python.framework import test_util
42from tensorflow.python.ops import array_ops
43from tensorflow.python.ops import control_flow_ops
44from tensorflow.python.ops import custom_gradient
45from tensorflow.python.ops import gradients_impl
46from tensorflow.python.ops import handle_data_util
47from tensorflow.python.ops import init_ops
48from tensorflow.python.ops import list_ops
49from tensorflow.python.ops import math_ops
50from tensorflow.python.ops import resource_variable_ops
51from tensorflow.python.ops import state_ops
52from tensorflow.python.ops import variable_scope
53from tensorflow.python.ops import variables
54from tensorflow.python.platform import test
55from tensorflow.python.training import momentum
56from tensorflow.python.training import saver
57from tensorflow.python.training import training_util
58from tensorflow.python.util import compat
59
60
61def _eager_safe_var_handle_op(*args, **kwargs):
62  # When running in eager mode the `shared_name` should be set to the
63  # `anonymous_name` to avoid spurious sharing issues. The runtime generates a
64  # unique name on our behalf when the reserved `anonymous_name` is used as the
65  # `shared_name`.
66  if context.executing_eagerly() and "shared_name" not in kwargs:
67    kwargs["shared_name"] = context.anonymous_name()
68  return resource_variable_ops.var_handle_op(*args, **kwargs)
69
70
71@test_util.with_eager_op_as_function
72@test_util.with_control_flow_v2
73class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
74                              parameterized.TestCase):
75
76  def tearDown(self):
77    gc.collect()
78    # This will only contain uncollectable garbage, i.e. reference cycles
79    # involving objects with __del__ defined.
80    self.assertEmpty(gc.garbage)
81    super(ResourceVariableOpsTest, self).tearDown()
82
83  @test_util.run_deprecated_v1
84  def testHandleDtypeShapeMatch(self):
85    with self.cached_session():
86      handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
87      with self.assertRaises(ValueError):
88        resource_variable_ops.assign_variable_op(
89            handle, constant_op.constant(0.0, dtype=dtypes.float32)).run()
90      with self.assertRaises(ValueError):
91        resource_variable_ops.assign_variable_op(handle,
92                                                 constant_op.constant(
93                                                     [0],
94                                                     dtype=dtypes.int32)).run()
95      resource_variable_ops.assign_variable_op(handle,
96                                               constant_op.constant(
97                                                   0,
98                                                   dtype=dtypes.int32)).run()
99
100  @test_util.run_gpu_only
101  def testGPUInt64(self):
102    with context.eager_mode(), context.device("gpu:0"):
103      v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int64)
104      self.assertAllEqual(1, v.numpy())
105
106  @test_util.run_gpu_only
107  def testGPUBfloat16(self):
108    with context.eager_mode(), ops.device("gpu:0"):
109      v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.bfloat16)
110      self.assertEqual("/job:localhost/replica:0/task:0/device:GPU:0",
111                       v.device)
112      self.assertAllEqual(1, v.numpy())
113
114  def testEagerNameNotIdentity(self):
115    with context.eager_mode():
116      v0 = resource_variable_ops.ResourceVariable(1.0, name="a")
117      v1 = resource_variable_ops.ResourceVariable(2.0, name="a")
118      self.assertAllEqual(v0.numpy(), 1.0)
119      self.assertAllEqual(v1.numpy(), 2.0)
120
121  def testEagerNameNotNeeded(self):
122    with context.eager_mode():
123      v0 = resource_variable_ops.ResourceVariable(1.0)
124      self.assertAllEqual(v0.numpy(), 1.0)
125
126  def testReadVariableDtypeMismatchEager(self):
127    with context.eager_mode():
128      handle = _eager_safe_var_handle_op(
129          dtype=dtypes.int32, shape=[1], name="foo")
130      resource_variable_ops.assign_variable_op(handle, 1)
131      # The error message varies depending on whether it is being raised
132      # by the kernel or shape inference. The shape inference code path can
133      # be reached when running in eager op as function mode where each op
134      # is wrapped in a tf.function.
135      with self.assertRaisesRegex(
136          errors.InvalidArgumentError,
137          r"Trying to read variable with wrong dtype. "
138          r"Expected (float|int32) got (int32|float)"):
139        _ = resource_variable_ops.read_variable_op(handle, dtype=dtypes.float32)
140
141  def testEagerInitializedValue(self):
142    with context.eager_mode():
143      variable = resource_variable_ops.ResourceVariable(1.0, name="eager-init")
144      self.assertAllEqual(variable.numpy(), 1.0)
145      self.assertAllEqual(variable.initialized_value().numpy(), 1.0)
146
147  def testInitializeVariableUsingInitializedValue(self):
148    var1 = resource_variable_ops.ResourceVariable(1.0, name="var1")
149    var2 = resource_variable_ops.ResourceVariable(var1.initialized_value(),
150                                                  name="var2")
151    self.assertAllEqual(var2.initialized_value(), 1.0)
152
153  def testEagerBool(self):
154    with context.eager_mode():
155      v = resource_variable_ops.ResourceVariable(False, name="bool_test")
156      self.assertAllEqual(bool(v), False)
157
158  def testEagerDeepCopy(self):
159    with context.eager_mode():
160      init_value = np.ones((4, 4, 4))
161      variable = resource_variable_ops.ResourceVariable(
162          init_value,
163          name="init",
164          synchronization=variables.VariableSynchronization.ON_READ,
165          aggregation=variables.VariableAggregation.SUM)
166
167      copied_variable = copy.deepcopy(variable)
168      self.assertEqual(variable.name, copied_variable.name)
169      self.assertEqual(variable.shape, copied_variable.shape)
170      self.assertEqual(variable.device, copied_variable.device)
171      self.assertEqual(variable.synchronization,
172                       copied_variable.synchronization)
173      self.assertEqual(variable.aggregation, copied_variable.aggregation)
174
175      # The copied variable should have the same value as the original.
176      self.assertAllEqual(variable.numpy(), copied_variable.numpy())
177
178      # Updates to the copy should not be reflected in the original.
179      copied_variable.assign(4 * np.ones((4, 4, 4)))
180      self.assertNotAllEqual(variable.numpy(), copied_variable.numpy())
181
182  @test_util.run_deprecated_v1
183  def testGraphDeepCopy(self):
184    with self.cached_session():
185      init_value = np.ones((4, 4, 4))
186      variable = resource_variable_ops.ResourceVariable(init_value,
187                                                        name="init")
188      with self.assertRaises(NotImplementedError):
189        copy.deepcopy(variable)
190
191  @test_util.run_in_graph_and_eager_modes
192  def testStridedSliceAssign(self):
193    v = resource_variable_ops.ResourceVariable([1.0, 2.0])
194    self.evaluate(variables.global_variables_initializer())
195    self.evaluate(v[0].assign(2.0))
196    self.assertAllEqual(self.evaluate(v), [2.0, 2.0])
197
198  @test_util.run_in_graph_and_eager_modes
199  def testVariableShape(self):
200    v = resource_variable_ops.ResourceVariable([1., 1.])
201    vshape = resource_variable_ops.variable_shape(v.handle)
202    self.assertAllEqual(
203        tensor_util.constant_value(vshape),
204        [2])
205    if not context.executing_eagerly():
206      self.assertEqual("Const", vshape.op.type)
207
208  @test_util.run_deprecated_v1
209  def testDifferentAssignGraph(self):
210    with ops.Graph().as_default():
211      v = resource_variable_ops.ResourceVariable(1.0)
212    ops.reset_default_graph()
213    v.assign(2.0)  # Note: this fails if we run convert_to_tensor on not the
214    # variable graph.
215
216  @test_util.run_deprecated_v1
217  def testFetchHandle(self):
218    with self.cached_session():
219      handle = _eager_safe_var_handle_op(
220          dtype=dtypes.int32, shape=[1], name="foo")
221      self.assertNotEmpty(self.evaluate(handle))
222
223  @test_util.run_deprecated_v1
224  def testCachedValueReadBeforeWrite(self):
225    with self.cached_session() as sess:
226      v = resource_variable_ops.ResourceVariable(0.0, caching_device="cpu:0")
227      self.evaluate(v.initializer)
228      value, _ = sess.run([v, v.assign_add(1.0)])
229      self.assertAllEqual(value, 0.0)
230
231  def testAssignVariableDtypeMismatchEager(self):
232    with context.eager_mode():
233      handle = _eager_safe_var_handle_op(
234          dtype=dtypes.int32, shape=[1], name="foo")
235      resource_variable_ops.assign_variable_op(
236          handle, constant_op.constant([1]))
237      # The error message varies depending on whether it is being raised
238      # by the kernel or shape inference. The shape inference code path can
239      # be reached when running in eager op as function mode where each op
240      # is wrapped in a tf.function.
241      with self.assertRaisesRegex(
242          errors.InvalidArgumentError, r"Trying to .* variable with wrong "
243          r"dtype. Expected int32 got float"):
244        resource_variable_ops.assign_variable_op(
245            handle, constant_op.constant([1.], dtype=dtypes.float32))
246
247  def testRepr(self):
248    with context.eager_mode():
249      v = resource_variable_ops.ResourceVariable(1)
250      text = "%r" % v
251      self.assertEqual(
252          "<tf.Variable 'Variable:0' shape=() dtype=int32, numpy=1>", text)
253
254  def testReprUnavailable(self):
255    with context.eager_mode():
256      v = resource_variable_ops.ResourceVariable(1)
257
258      # Monkey-patch this variable to not have an available value
259      def broken_read():
260        raise ValueError("This doesn't work")
261
262      v.read_value = broken_read
263      text = "%r" % v
264      self.assertEqual("<tf.Variable 'Variable:0' shape=() dtype=int32,"
265                       " numpy=<unavailable>>", text)
266
267  def testFormatResourceHandle(self):
268    with context.eager_mode():
269      handle = _eager_safe_var_handle_op(
270          dtype=dtypes.int32, shape=[1], name="foo")
271      self.assertIn("<ResourceHandle", str(handle))
272      self.assertIn("<ResourceHandle", repr(handle))
273
274  @test_util.run_in_graph_and_eager_modes
275  def testDtypeSurvivesIdentity(self):
276    handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
277    id_handle = array_ops.identity(handle)
278    self.evaluate(resource_variable_ops.assign_variable_op(
279        id_handle, constant_op.constant(0, dtype=dtypes.int32)))
280
281  def testUnreadOpName(self):
282    v = resource_variable_ops.ResourceVariable(1.0)
283    self.assertNotEqual(v.name, v.assign_add(1.0).name)
284
285  @test_util.run_in_graph_and_eager_modes
286  def testCreateRead(self):
287    handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
288    self.evaluate(resource_variable_ops.assign_variable_op(
289        handle, constant_op.constant(1, dtype=dtypes.int32)))
290    value = self.evaluate(
291        resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32))
292    self.assertAllEqual(1, value)
293
294  @test_util.run_in_graph_and_eager_modes
295  def testManyAssigns(self):
296    handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
297    create = resource_variable_ops.assign_variable_op(
298        handle, constant_op.constant(1, dtype=dtypes.int32))
299    with ops.control_dependencies([create]):
300      first_read = resource_variable_ops.read_variable_op(
301          handle, dtype=dtypes.int32)
302    with ops.control_dependencies([first_read]):
303      write = resource_variable_ops.assign_variable_op(
304          handle, constant_op.constant(2, dtype=dtypes.int32))
305    with ops.control_dependencies([write]):
306      second_read = resource_variable_ops.read_variable_op(
307          handle, dtype=dtypes.int32)
308    f, s = self.evaluate([first_read, second_read])
309    self.assertEqual(f, 1)
310    self.assertEqual(s, 2)
311
312  @test_util.run_in_graph_and_eager_modes
313  def testAssignAdd(self):
314    handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
315    self.evaluate(resource_variable_ops.assign_variable_op(
316        handle, constant_op.constant(1, dtype=dtypes.int32)))
317    self.evaluate(resource_variable_ops.assign_add_variable_op(
318        handle, constant_op.constant(1, dtype=dtypes.int32)))
319    read = self.evaluate(
320        resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32))
321    self.assertEqual(read, 2)
322
323  @test_util.run_in_graph_and_eager_modes
324  def testScatterAdd(self):
325    handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
326    self.evaluate(
327        resource_variable_ops.assign_variable_op(
328            handle, constant_op.constant([[1]], dtype=dtypes.int32)))
329    self.evaluate(
330        resource_variable_ops.resource_scatter_add(
331            handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)))
332    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
333    self.assertEqual(self.evaluate(read), [[3]])
334
335  @test_util.run_in_graph_and_eager_modes
336  def testGradientGatherNd(self):
337    v = resource_variable_ops.ResourceVariable(
338        np.random.uniform(size=[2, 2]), dtype=dtypes.float32)
339
340    with backprop.GradientTape() as tape:
341      l = array_ops.gather_nd(v, [[1, 1]])
342      l = math_ops.reduce_sum(l)
343
344    grads = tape.gradient(l, v)
345    self.evaluate(variables.global_variables_initializer())
346    self.assertAllEqual(self.evaluate(grads), [[0., 0.], [0., 1.]])
347
348  @test_util.run_deprecated_v1
349  def testDefaultGradientDtype(self):
350    v = resource_variable_ops.ResourceVariable(
351        np.random.uniform(size=[2, 2]), dtype=dtypes.float64)
352
353    c = constant_op.constant(1.)
354    identity = array_ops.identity_n([c, v.handle])
355    # TODO(b/137403775): Remove this.
356    handle_data_util.copy_handle_data(v.handle, identity[1])
357
358    g = gradients_impl.gradients(identity[0], [c, v.handle])
359    self.assertEqual(g[1].dtype, dtypes.float64)
360    self.evaluate(variables.global_variables_initializer())
361    self.assertAllEqual(g[1], [[0., 0.], [0., 0.]])
362
363  @test_util.run_deprecated_v1
364  def testUnconnectedGradientZeros(self):
365    b = resource_variable_ops.ResourceVariable(initial_value=[[3., 4.]])
366    c = constant_op.constant(0.)
367    g = gradients_impl.gradients(c, [b], unconnected_gradients="zero")[0]
368    self.assertAllEqual(g.shape.as_list(), [1, 2])
369
370  @test_util.run_deprecated_v1
371  def testGradientCondInWhileLoop(self):
372    v = resource_variable_ops.ResourceVariable(initial_value=1.0)
373    def cond(i, unused_x):
374      return i < 1
375
376    def body(i, x):
377      def true():
378        return x + v
379      def false():
380        return 2.0 * v
381      return i + 1, control_flow_ops.cond(i > 0, true, false)
382
383    _, x = control_flow_ops.while_loop(cond, body, [0, 0.0])
384    # Computing gradients does not produce an exception:
385    g = gradients_impl.gradients(x, v)
386    self.evaluate(variables.global_variables_initializer())
387    # Only the false branch is taken so the gradient is 2.
388    self.assertAllEqual(g[0], 2.0)
389
390  @test_util.run_in_graph_and_eager_modes
391  def testGradientGatherNdIndexedSlices(self):
392    v = resource_variable_ops.ResourceVariable(
393        np.random.uniform(size=[2, 2]), dtype=dtypes.float32)
394
395    with backprop.GradientTape() as tape:
396      l = array_ops.gather_nd(v, [[1], [1]])
397      l = math_ops.reduce_sum(l)
398
399    grads = tape.gradient(l, v)
400    self.evaluate(variables.global_variables_initializer())
401    self.assertAllEqual(self.evaluate(grads.values), [[1., 1.], [1., 1.]])
402
403  @test_util.run_in_graph_and_eager_modes
404  def testScatterSub(self):
405    handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
406    self.evaluate(
407        resource_variable_ops.assign_variable_op(
408            handle, constant_op.constant([[1]], dtype=dtypes.int32)))
409    self.evaluate(
410        resource_variable_ops.resource_scatter_sub(
411            handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)))
412    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
413    self.assertEqual(self.evaluate(read), [[-1]])
414
415  @test_util.run_in_graph_and_eager_modes
416  def testScatterMul(self):
417    handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
418    self.evaluate(
419        resource_variable_ops.assign_variable_op(
420            handle, constant_op.constant([[1]], dtype=dtypes.int32)))
421    self.evaluate(
422        resource_variable_ops.resource_scatter_mul(
423            handle, [0], constant_op.constant([[5]], dtype=dtypes.int32)))
424    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
425    self.assertEqual(self.evaluate(read), [[5]])
426
427  def testEagerPickle(self):
428    with context.eager_mode():
429      tmp_dir = self.get_temp_dir()
430      fname = os.path.join(tmp_dir, "var.pickle")
431      with open(fname, "wb") as f:
432        v = resource_variable_ops.ResourceVariable(
433            10.0,
434            dtype=dtypes.float16,
435            name="v")
436        pickle.dump(v, f)
437
438      with open(fname, "rb") as f:
439        new_v = pickle.load(f)
440        self.assertEqual(new_v.name, v.name)
441        self.assertEqual(new_v.shape, v.shape)
442        self.assertEqual(new_v.dtype, v.dtype)
443        self.assertEqual(new_v.trainable, v.trainable)
444        self.assertAllEqual(new_v.numpy(), v.numpy())
445
446  @test_util.run_in_graph_and_eager_modes
447  def testScatterDiv(self):
448    handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
449    self.evaluate(
450        resource_variable_ops.assign_variable_op(
451            handle, constant_op.constant([[6]], dtype=dtypes.int32)))
452    self.evaluate(
453        resource_variable_ops.resource_scatter_div(
454            handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
455    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
456    self.assertEqual(self.evaluate(read), [[2]])
457
458  def testUseResource(self):
459    v = variables.VariableV1(1.0, use_resource=True)
460    self.assertIsInstance(v, resource_variable_ops.ResourceVariable)
461
462  def testEagerNoUseResource(self):
463    with context.eager_mode():
464      v = variables.Variable(1.0)
465      self.assertIsInstance(v, resource_variable_ops.ResourceVariable)
466
467  @test_util.run_in_graph_and_eager_modes
468  def testScatterMin(self):
469    with ops.device("cpu:0"):
470      handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
471      self.evaluate(
472          resource_variable_ops.assign_variable_op(handle,
473                                                   constant_op.constant(
474                                                       [[6]],
475                                                       dtype=dtypes.int32)))
476      self.evaluate(
477          resource_variable_ops.resource_scatter_min(handle, [0],
478                                                     constant_op.constant(
479                                                         [[3]],
480                                                         dtype=dtypes.int32)))
481      read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
482      self.assertEqual(self.evaluate(read), [[3]])
483
484  def testMetagraph(self):
485    with ops.Graph().as_default():
486      with variable_scope.variable_scope("foo", use_resource=True):
487        a = variable_scope.get_variable("a", initializer=10.0)
488
489      momentum.MomentumOptimizer(
490          learning_rate=0.001, momentum=0.1).minimize(
491              a,
492              colocate_gradients_with_ops=True,
493              global_step=training_util.get_or_create_global_step())
494
495      graph = ops.get_default_graph()
496      meta_graph_def = saver.export_meta_graph(graph=graph)
497
498    with ops.Graph().as_default():
499      saver.import_meta_graph(meta_graph_def, import_scope="")
500      meta_graph_two = saver.export_meta_graph(graph=graph)
501    self.assertEqual(meta_graph_def, meta_graph_two)
502
503  @test_util.run_in_graph_and_eager_modes
504  def testScatterMax(self):
505    handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
506    self.evaluate(
507        resource_variable_ops.assign_variable_op(
508            handle, constant_op.constant([[6]], dtype=dtypes.int32)))
509    self.evaluate(
510        resource_variable_ops.resource_scatter_max(
511            handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
512    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
513    self.assertEqual(self.evaluate(read), [[6]])
514
515  @test_util.run_in_graph_and_eager_modes
516  def testScatterAddScalar(self):
517    handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
518    self.evaluate(
519        resource_variable_ops.assign_variable_op(
520            handle, constant_op.constant([[1]], dtype=dtypes.int32)))
521    self.evaluate(
522        resource_variable_ops.resource_scatter_add(
523            handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
524    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
525    self.assertEqual(self.evaluate(read), [[3]])
526
527  @test_util.run_in_graph_and_eager_modes
528  def testScatterSubScalar(self):
529    handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
530    self.evaluate(
531        resource_variable_ops.assign_variable_op(
532            handle, constant_op.constant([[1]], dtype=dtypes.int32)))
533    self.evaluate(
534        resource_variable_ops.resource_scatter_sub(
535            handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
536    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
537    self.assertEqual(self.evaluate(read), [[-1]])
538
539  @test_util.run_in_graph_and_eager_modes
540  def testScatterMulScalar(self):
541    handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
542    self.evaluate(
543        resource_variable_ops.assign_variable_op(
544            handle, constant_op.constant([[1]], dtype=dtypes.int32)))
545    self.evaluate(
546        resource_variable_ops.resource_scatter_mul(
547            handle, [0], constant_op.constant(5, dtype=dtypes.int32)))
548    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
549    self.assertEqual(self.evaluate(read), [[5]])
550
551  @test_util.run_in_graph_and_eager_modes
552  def testScatterDivScalar(self):
553    handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
554    self.evaluate(
555        resource_variable_ops.assign_variable_op(
556            handle, constant_op.constant([[6]], dtype=dtypes.int32)))
557    self.evaluate(
558        resource_variable_ops.resource_scatter_div(
559            handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
560    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
561    self.assertEqual(self.evaluate(read), [[2]])
562
563  @test_util.run_in_graph_and_eager_modes
564  def testScatterMinScalar(self):
565    handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
566    self.evaluate(
567        resource_variable_ops.assign_variable_op(
568            handle, constant_op.constant([[6]], dtype=dtypes.int32)))
569    self.evaluate(
570        resource_variable_ops.resource_scatter_min(
571            handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
572    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
573    self.assertEqual(self.evaluate(read), [[3]])
574
575  @test_util.run_in_graph_and_eager_modes
576  def testScatterMaxScalar(self):
577    handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
578    self.evaluate(
579        resource_variable_ops.assign_variable_op(
580            handle, constant_op.constant([[6]], dtype=dtypes.int32)))
581    self.evaluate(
582        resource_variable_ops.resource_scatter_max(
583            handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
584    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
585    self.assertEqual(self.evaluate(read), [[6]])
586
587  @parameterized.parameters(dtypes.float16, dtypes.float32, dtypes.float64)
588  @test_util.run_in_graph_and_eager_modes
589  def testScatterAddVariableMethod(self, dtype):
590    v = resource_variable_ops.ResourceVariable([0.0, 1.5],
591                                               name="add",
592                                               dtype=dtype)
593    self.evaluate(variables.global_variables_initializer())
594    self.evaluate(
595        v.scatter_add(
596            indexed_slices.IndexedSlices(
597                indices=[1], values=constant_op.constant([2.5], dtype=dtype))))
598    self.assertAllCloseAccordingToType([0.0, 4.0], self.evaluate(v))
599
600  @parameterized.parameters(dtypes.float16, dtypes.float32, dtypes.float64)
601  @test_util.run_in_graph_and_eager_modes
602  def testScatterSubVariableMethod(self, dtype):
603    v = resource_variable_ops.ResourceVariable([0.0, 2.5],
604                                               name="sub",
605                                               dtype=dtype)
606    self.evaluate(variables.global_variables_initializer())
607    self.evaluate(
608        v.scatter_sub(
609            indexed_slices.IndexedSlices(
610                indices=[1], values=constant_op.constant([1.5], dtype=dtype))))
611    self.assertAllCloseAccordingToType([0.0, 1.0], self.evaluate(v))
612
613  @parameterized.parameters(dtypes.float16, dtypes.float32, dtypes.float64)
614  @test_util.run_in_graph_and_eager_modes
615  def testScatterMaxVariableMethod(self, dtype):
616    v = resource_variable_ops.ResourceVariable([0.0, 4.0],
617                                               name="max1",
618                                               dtype=dtype)
619    self.evaluate(variables.global_variables_initializer())
620    self.evaluate(
621        v.scatter_max(
622            indexed_slices.IndexedSlices(
623                indices=[1], values=constant_op.constant([5.0], dtype=dtype))))
624    self.assertAllCloseAccordingToType([0.0, 5.0], self.evaluate(v))
625
626    v = resource_variable_ops.ResourceVariable([0.0, 3.5],
627                                               name="max2",
628                                               dtype=dtype)
629    self.evaluate(variables.global_variables_initializer())
630    self.evaluate(
631        v.scatter_max(
632            indexed_slices.IndexedSlices(
633                indices=[1], values=constant_op.constant([2.0], dtype=dtype))))
634    self.assertAllCloseAccordingToType([0.0, 3.5], self.evaluate(v))
635
636  @parameterized.parameters(dtypes.float16, dtypes.float32, dtypes.float64)
637  @test_util.run_in_graph_and_eager_modes
638  def testScatterMinVariableMethod(self, dtype):
639    v = resource_variable_ops.ResourceVariable([0.0, 4.0],
640                                               name="min1",
641                                               dtype=dtype)
642    self.evaluate(variables.global_variables_initializer())
643    self.evaluate(
644        v.scatter_min(
645            indexed_slices.IndexedSlices(
646                indices=[1], values=constant_op.constant([5.0], dtype=dtype))))
647    self.assertAllCloseAccordingToType([0.0, 4.0], self.evaluate(v))
648
649    v = resource_variable_ops.ResourceVariable([0.0, 3.5],
650                                               name="min2",
651                                               dtype=dtype)
652    self.evaluate(variables.global_variables_initializer())
653    self.evaluate(
654        v.scatter_min(
655            indexed_slices.IndexedSlices(
656                indices=[1], values=constant_op.constant([2.0], dtype=dtype))))
657    self.assertAllCloseAccordingToType([0.0, 2.0], self.evaluate(v))
658
659  @parameterized.parameters(dtypes.float16, dtypes.float32, dtypes.float64)
660  @test_util.run_in_graph_and_eager_modes
661  def testScatterMulVariableMethod(self, dtype):
662    v = resource_variable_ops.ResourceVariable([0.0, 4.0],
663                                               name="mul",
664                                               dtype=dtype)
665    self.evaluate(variables.global_variables_initializer())
666    self.evaluate(
667        v.scatter_mul(
668            indexed_slices.IndexedSlices(
669                indices=[1], values=constant_op.constant([3.0], dtype=dtype))))
670    self.assertAllCloseAccordingToType([0.0, 12.0], self.evaluate(v))
671
672  @parameterized.parameters(dtypes.float16, dtypes.float32, dtypes.float64)
673  @test_util.run_in_graph_and_eager_modes
674  def testScatterDivVariableMethod(self, dtype):
675    v = resource_variable_ops.ResourceVariable([0.0, 6.0],
676                                               name="div",
677                                               dtype=dtype)
678    self.evaluate(variables.global_variables_initializer())
679    self.evaluate(
680        v.scatter_div(
681            indexed_slices.IndexedSlices(
682                indices=[1], values=constant_op.constant([2.0], dtype=dtype))))
683    self.assertAllCloseAccordingToType([0.0, 3.0], self.evaluate(v))
684
685  @parameterized.parameters(dtypes.float16, dtypes.float32, dtypes.float64)
686  @test_util.run_in_graph_and_eager_modes
687  def testScatterUpdateVariableMethod(self, dtype):
688    v = resource_variable_ops.ResourceVariable([0.0, 6.0],
689                                               name="update",
690                                               dtype=dtype)
691    self.evaluate(variables.global_variables_initializer())
692    self.evaluate(
693        v.scatter_update(
694            indexed_slices.IndexedSlices(
695                indices=[1], values=constant_op.constant([3.0], dtype=dtype))))
696    self.assertAllCloseAccordingToType([0.0, 3.0], self.evaluate(v))
697
698  @test_util.run_deprecated_v1
699  def testScatterUpdateString(self):
700    handle = _eager_safe_var_handle_op(dtype=dtypes.string, shape=[1, 1])
701    self.evaluate(resource_variable_ops.assign_variable_op(
702        handle, constant_op.constant([["a"]], dtype=dtypes.string)))
703    self.evaluate(resource_variable_ops.resource_scatter_update(
704        handle, [0], constant_op.constant([["b"]], dtype=dtypes.string)))
705    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.string)
706    self.assertEqual(compat.as_bytes(self.evaluate(read)[0][0]),
707                     compat.as_bytes("b"))
708
709  @test_util.run_deprecated_v1
710  def testScatterUpdateStringScalar(self):
711    handle = _eager_safe_var_handle_op(dtype=dtypes.string, shape=[1, 1])
712    self.evaluate(
713        resource_variable_ops.assign_variable_op(handle,
714                                                 constant_op.constant(
715                                                     [["a"]],
716                                                     dtype=dtypes.string)))
717    self.evaluate(
718        resource_variable_ops.resource_scatter_update(handle, [0],
719                                                      constant_op.constant(
720                                                          "b",
721                                                          dtype=dtypes.string)))
722    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.string)
723    self.assertEqual(
724        compat.as_bytes(self.evaluate(read)[0][0]), compat.as_bytes("b"))
725
726  # TODO(alive): get this to work in Eager mode.
727  def testGPU(self):
728    with test_util.use_gpu():
729      abc = variable_scope.get_variable(
730          "abc",
731          shape=[1],
732          initializer=init_ops.ones_initializer(),
733          use_resource=True)
734
735      self.evaluate(variables.global_variables_initializer())
736      self.assertEqual(
737          self.evaluate(
738              resource_variable_ops.var_is_initialized_op(abc.handle)),
739          True)
740
741  def testScatterBool(self):
742    with context.eager_mode():
743      ref = resource_variable_ops.ResourceVariable(
744          [False, True, False], trainable=False)
745      indices = math_ops.range(3)
746      updates = constant_op.constant([True, True, True])
747      state_ops.scatter_update(ref, indices, updates)
748      self.assertAllEqual(ref.read_value(), [True, True, True])
749
750  @test_util.run_in_graph_and_eager_modes
751  def testConstraintArg(self):
752    constraint = lambda x: x
753    v = resource_variable_ops.ResourceVariable(
754        initial_value=lambda: 1, constraint=constraint, name="var0")
755    self.assertEqual(v.constraint, constraint)
756
757    constraint = 0
758    with self.assertRaises(ValueError):
759      v = resource_variable_ops.ResourceVariable(
760          initial_value=lambda: 1, constraint=constraint, name="var1")
761
762  # TODO(alive): how should this work in Eager mode?
763  @test_util.run_deprecated_v1
764  def testInitFn(self):
765    with self.cached_session():
766      v = resource_variable_ops.ResourceVariable(
767          initial_value=lambda: 1, dtype=dtypes.float32)
768      self.assertEqual(v.handle.op.colocation_groups(),
769                       v.initializer.inputs[1].op.colocation_groups())
770
771  def testCountUpTo(self):
772    with context.eager_mode():
773      v = resource_variable_ops.ResourceVariable(0, name="upto")
774      self.assertAllEqual(v.count_up_to(1), 0)
775      with self.assertRaises(errors.OutOfRangeError):
776        v.count_up_to(1)
777
778  def testCountUpToFunction(self):
779    with context.eager_mode():
780      v = resource_variable_ops.ResourceVariable(0, name="upto")
781      self.assertAllEqual(state_ops.count_up_to(v, 1), 0)
782      with self.assertRaises(errors.OutOfRangeError):
783        state_ops.count_up_to(v, 1)
784
785  @test_util.run_in_graph_and_eager_modes
786  def testInitFnDtype(self):
787    v = resource_variable_ops.ResourceVariable(
788        initial_value=lambda: 1, dtype=dtypes.float32, name="var0")
789    self.assertEqual(dtypes.float32, v.value().dtype)
790
791  @test_util.run_in_graph_and_eager_modes
792  def testInitFnNoDtype(self):
793    v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1,
794                                               name="var2")
795    self.assertEqual(dtypes.int32, v.value().dtype)
796
797  @test_util.run_in_graph_and_eager_modes
798  def testInitializeAllVariables(self):
799    v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.float32,
800                                               name="var0")
801    self.evaluate(variables.global_variables_initializer())
802    self.assertEqual(1.0, self.evaluate(v.value()))
803
804  @test_util.run_in_graph_and_eager_modes
805  def testOperatorOverload(self):
806    v = resource_variable_ops.ResourceVariable(1.0, name="var0")
807    self.evaluate(variables.global_variables_initializer())
808    self.assertEqual(2.0, self.evaluate(v + v))
809
810  @test_util.run_in_graph_and_eager_modes
811  def testAssignMethod(self):
812    v = resource_variable_ops.ResourceVariable(1.0, name="var0")
813    self.evaluate(variables.global_variables_initializer())
814    self.evaluate(v.assign(2.0))
815    self.assertEqual(2.0, self.evaluate(v.value()))
816
817    # Tests for the 'read_value' argument:
818    assign_with_read = v.assign(3.0, read_value=True)
819    self.assertEqual(3.0, self.evaluate(assign_with_read))
820    assign_without_read = v.assign(4.0, read_value=False)
821    if context.executing_eagerly():
822      self.assertIsNone(assign_without_read)
823    else:
824      self.assertIsInstance(assign_without_read, ops.Operation)
825    self.evaluate(assign_without_read)
826    self.assertEqual(4.0, self.evaluate(v.value()))
827
828  def testAssignRuntimeShapeCheck(self):
829    with forward_compat.forward_compatibility_horizon(2022, 3, 30):
830      v = resource_variable_ops.ResourceVariable([1.0, 1.0], name="var0")
831
832      @def_function.function
833      def f(shape):
834        t = array_ops.zeros(shape)
835        v.assign(t)
836
837      with self.assertRaises((errors.InvalidArgumentError, ValueError)):
838        f(constant_op.constant([3]))
839
840  @test_util.run_in_graph_and_eager_modes
841  def testLoad(self):
842    v = resource_variable_ops.ResourceVariable(1.0, name="var0")
843    self.evaluate(variables.global_variables_initializer())
844    v.load(2.0)
845    self.assertEqual(2.0, self.evaluate(v.value()))
846
847  def testShapePassedToGradient(self):
848    with ops.Graph().as_default():
849      @custom_gradient.custom_gradient
850      def differentiable_scatter_update(handle, indices, values):
851        with ops.control_dependencies([
852            resource_variable_ops.resource_scatter_update(
853                handle, indices, values)]):
854          new_handle = array_ops.identity(handle)
855
856        def grad(dresult):
857          self.assertIsNotNone(
858              tensor_util.constant_value(dresult.dense_shape))
859          return [dresult, None, None]
860
861        return new_handle, grad
862
863      var = variable_scope.get_variable(
864          "foo", shape=[20], initializer=init_ops.zeros_initializer,
865          dtype=dtypes.float64, use_resource=True)
866
867      indices = math_ops.range(10)
868      updates = math_ops.range(9, -1, -1, dtype=dtypes.float64)
869      new_handle = differentiable_scatter_update(var.handle, indices, updates)
870      gathered = resource_variable_ops.resource_gather(
871          new_handle, indices, dtype=var.dtype)
872      gradients_impl.gradients([gathered], [updates])
873
874  def testToFromProtoCachedValue(self):
875    with ops.Graph().as_default():
876      v_def = resource_variable_ops.ResourceVariable(
877          initial_value=constant_op.constant(3.0)).to_proto()
878      v_prime = resource_variable_ops.ResourceVariable(variable_def=v_def)
879      self.assertIsNone(getattr(v_prime, "_cached_value", None))
880
881      other_v_def = resource_variable_ops.ResourceVariable(
882          caching_device="cpu:0",
883          initial_value=constant_op.constant(3.0)).to_proto()
884      other_v_prime = resource_variable_ops.ResourceVariable(
885          variable_def=other_v_def)
886      self.assertIsNotNone(other_v_prime._cached_value)
887
888  def testVariableDefInitializedInstances(self):
889    with ops.Graph().as_default(), self.cached_session():
890      v_def = resource_variable_ops.ResourceVariable(
891          initial_value=constant_op.constant(3.0)).to_proto()
892
893    with ops.Graph().as_default(), self.cached_session():
894      # v describes a VariableDef-based variable without an initial value.
895      v = resource_variable_ops.ResourceVariable(variable_def=v_def)
896      self.assertEqual(3.0, self.evaluate(v.initialized_value()))
897
898      # initialized_value should not rerun the initializer_op if the variable
899      # has already been initialized elsewhere.
900      self.evaluate(v.assign(1.0))
901      self.assertEqual(1.0, v.initialized_value().eval())
902
903    v_def.ClearField("initial_value_name")
904    with ops.Graph().as_default(), self.cached_session():
905      # Restoring a legacy VariableDef proto that does not have
906      # initial_value_name set should still work.
907      v = resource_variable_ops.ResourceVariable(variable_def=v_def)
908      # We should also be able to re-export the variable to a new meta graph.
909      self.assertProtoEquals(v_def, v.to_proto())
910      # But attempts to use initialized_value will result in errors.
911      with self.assertRaises(ValueError):
912        self.evaluate(v.initialized_value())
913
914  def testTrainableInProto(self):
915    with ops.Graph().as_default():
916      non_trainable_variable = resource_variable_ops.ResourceVariable(
917          trainable=False,
918          initial_value=constant_op.constant(10.0))
919      self.assertEqual(
920          False,
921          resource_variable_ops.ResourceVariable(
922              variable_def=non_trainable_variable.to_proto())
923          .trainable)
924      trainable_variable = resource_variable_ops.ResourceVariable(
925          trainable=True,
926          initial_value=constant_op.constant(10.0))
927      self.assertEqual(
928          True,
929          resource_variable_ops.ResourceVariable(
930              variable_def=trainable_variable.to_proto())
931          .trainable)
932
933  @test_util.run_in_graph_and_eager_modes
934  def testSparseRead(self):
935    init_value = np.reshape(np.arange(np.power(4, 3)), (4, 4, 4))
936    v = resource_variable_ops.ResourceVariable(
937        constant_op.constant(init_value, dtype=dtypes.int32), name="var3")
938    self.evaluate(variables.global_variables_initializer())
939
940    value = self.evaluate(v.sparse_read([0, 3, 1, 2]))
941    self.assertAllEqual(init_value[[0, 3, 1, 2], ...], value)
942
943  @test_util.run_in_graph_and_eager_modes
944  def testGatherNd(self):
945    init_value = np.reshape(np.arange(np.power(4, 3)), (4, 4, 4))
946    v = resource_variable_ops.ResourceVariable(
947        constant_op.constant(init_value, dtype=dtypes.int32), name="var3")
948    self.evaluate(variables.global_variables_initializer())
949
950    value_op = v.gather_nd([[0, 0], [1, 2], [3, 3]])
951    self.assertAllEqual([3, 4], value_op.shape)
952    value = self.evaluate(value_op)
953    self.assertAllEqual([[0, 1, 2, 3], [24, 25, 26, 27], [60, 61, 62, 63]],
954                        value)
955
956    value_op = v.gather_nd([[0, 0, 0], [1, 2, 3], [3, 3, 3]])
957    self.assertAllEqual([3], value_op.shape)
958    value = self.evaluate(value_op)
959    self.assertAllEqual([0, 27, 63], value)
960
961  @test_util.run_deprecated_v1
962  def testToFromProto(self):
963    with self.cached_session():
964      v = resource_variable_ops.ResourceVariable(1.0)
965      self.evaluate(variables.global_variables_initializer())
966
967      w = resource_variable_ops.ResourceVariable.from_proto(v.to_proto())
968      self.assertEqual(2, math_ops.add(w, 1).eval())
969
970      self.assertEqual(v._handle, w._handle)
971      self.assertEqual(v._graph_element, w._graph_element)
972
973  @test_util.run_in_graph_and_eager_modes
974  def testAssignAddMethod(self):
975    v = resource_variable_ops.ResourceVariable(1.0, name="var0")
976    self.evaluate(variables.global_variables_initializer())
977    self.evaluate(v.assign_add(1.0))
978    self.assertEqual(2.0, self.evaluate(v.value()))
979
980    # Tests for the 'read_value' argument:
981    assign_with_read = v.assign_add(1.0, read_value=True)
982    self.assertEqual(3.0, self.evaluate(assign_with_read))
983    assign_without_read = v.assign_add(1.0, read_value=False)
984    if context.executing_eagerly():
985      self.assertIsNone(assign_without_read)
986    else:
987      self.assertIsInstance(assign_without_read, ops.Operation)
988    self.evaluate(assign_without_read)
989    self.assertEqual(4.0, self.evaluate(v.value()))
990
991  @test_util.run_in_graph_and_eager_modes
992  def testAssignSubMethod(self):
993    v = resource_variable_ops.ResourceVariable(3.0, name="var0")
994    self.evaluate(variables.global_variables_initializer())
995    self.evaluate(v.assign_sub(1.0))
996    self.assertEqual(2.0, self.evaluate(v.value()))
997
998    # Tests for the 'read_value' argument:
999    assign_with_read = v.assign_sub(1.0, read_value=True)
1000    self.assertEqual(1.0, self.evaluate(assign_with_read))
1001    assign_without_read = v.assign_sub(1.0, read_value=False)
1002    if context.executing_eagerly():
1003      self.assertIsNone(assign_without_read)
1004    else:
1005      self.assertIsInstance(assign_without_read, ops.Operation)
1006    self.evaluate(assign_without_read)
1007    self.assertEqual(0.0, self.evaluate(v.value()))
1008
1009  @test_util.run_in_graph_and_eager_modes
1010  @test_util.run_v1_only("b/120545219")
1011  def testDestroyResource(self):
1012    v = resource_variable_ops.ResourceVariable(3.0, name="var0")
1013    self.evaluate(variables.global_variables_initializer())
1014    self.assertEqual(3.0, self.evaluate(v.value()))
1015    self.evaluate(resource_variable_ops.destroy_resource_op(v.handle))
1016    if context.executing_eagerly():
1017      # eager mode creates ref-counting variable handles unaffected by
1018      # DestroyResourceOp.
1019      self.assertEqual(3.0, self.evaluate(v.value()))
1020    else:
1021      with self.assertRaises(errors.FailedPreconditionError):
1022        self.evaluate(v.value())
1023    # Handle to a resource not actually created.
1024    handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
1025    # Should raise no exception
1026    self.evaluate(resource_variable_ops.destroy_resource_op(
1027        handle, ignore_lookup_error=True))
1028
1029  @test_util.run_deprecated_v1
1030  def testAssignDifferentShapes(self):
1031    with self.cached_session() as sess, variable_scope.variable_scope(
1032        "foo", use_resource=True):
1033      var = variable_scope.get_variable("x", shape=[1, 1], dtype=dtypes.float32)
1034      placeholder = array_ops.placeholder(dtypes.float32)
1035      assign = var.assign(placeholder)
1036      sess.run(
1037          [assign],
1038          feed_dict={placeholder: np.zeros(shape=[2, 2], dtype=np.float32)})
1039
1040  def testAssignDifferentShapesEagerNotAllowed(self):
1041    with context.eager_mode():
1042      with variable_scope.variable_scope("foo"):
1043        var = variable_scope.get_variable("x", shape=[1, 1],
1044                                          dtype=dtypes.float32)
1045        with self.assertRaisesRegex(ValueError,
1046                                    "shape.*and.*are incompatible"):
1047          assign = var.assign(np.zeros(shape=[2, 2]))
1048          self.evaluate(assign)
1049
1050  @test_util.disable_xla("XLA doesn't allow changing shape at assignment, as "
1051                         "dictated by tf2xla/xla_resource.cc:SetTypeAndShape")
1052  @test_util.run_in_graph_and_eager_modes
1053  def testAssignDifferentShapesAllowed(self):
1054    var = resource_variable_ops.ResourceVariable(
1055        initial_value=np.zeros(shape=[1, 1]),
1056        shape=tensor_shape.TensorShape(None))
1057    self.evaluate(variables.global_variables_initializer())
1058    self.assertAllEqual(np.zeros(shape=[1, 1]), var.read_value())
1059    self.evaluate(var.assign(np.zeros(shape=[2, 2])))
1060    self.assertAllEqual(np.zeros(shape=[2, 2]), var.read_value())
1061
1062  @test_util.run_in_graph_and_eager_modes
1063  def testAssignReturnsVariable(self):
1064    var = resource_variable_ops.ResourceVariable(1.)
1065    self.evaluate(variables.global_variables_initializer())
1066    assigned = var.assign(2.)
1067    self.assertIsInstance(assigned, resource_variable_ops.BaseResourceVariable)
1068    assigned = assigned.assign(3.)
1069    self.assertEqual(self.evaluate(assigned), 3.)
1070    self.assertEqual(self.evaluate(var), 3.)
1071
1072    self.assertEqual(self.evaluate(var.assign_add(1.).assign_add(1.)), 5)
1073    self.assertEqual(self.evaluate(var.assign_sub(1.).assign_sub(1.)), 3)
1074
1075    var = resource_variable_ops.ResourceVariable([1., 2.])
1076    self.evaluate(variables.global_variables_initializer())
1077    slices = indexed_slices.IndexedSlices(indices=[1], values=[2])
1078    def assert_eq(tensor, vals):
1079      self.assertAllEqual(self.evaluate(tensor), vals)
1080    assert_eq(var.scatter_add(slices).scatter_add(slices), [1., 6.])
1081    assert_eq(var.scatter_sub(slices).scatter_sub(slices), [1., 2.])
1082    slices2 = indexed_slices.IndexedSlices(indices=[0], values=[3])
1083    assert_eq(var.scatter_max(slices2).scatter_add(slices), [3., 4.])
1084    assert_eq(var.scatter_add(slices).scatter_min(slices), [3., 2.])
1085    assert_eq(var.scatter_mul(slices).scatter_mul(slices), [3., 8.])
1086    assert_eq(var.scatter_div(slices).scatter_div(slices), [3., 2.])
1087    assert_eq(
1088        var.scatter_nd_update([[1]], [4.]).scatter_nd_add([[0]], [2.])
1089        .scatter_nd_sub([[1]], [3]),
1090        [5., 1.])
1091    assert_eq(var, [5., 1.])
1092
1093    batch_var = resource_variable_ops.ResourceVariable(array_ops.ones((2, 2)))
1094    self.evaluate(variables.global_variables_initializer())
1095    batch_slices1 = indexed_slices.IndexedSlices(
1096        indices=[[1], [0]], values=[[2], [2]])
1097    batch_slices2 = indexed_slices.IndexedSlices(
1098        indices=[[1], [1]], values=[[3], [3]])
1099    assert_eq(
1100        batch_var.batch_scatter_update(batch_slices1)
1101        .batch_scatter_update(batch_slices2),
1102        [[1, 3], [2, 3]])
1103
1104  @test_util.run_in_graph_and_eager_modes
1105  def testInitValueWrongShape(self):
1106    with self.assertRaisesWithPredicateMatch(
1107        ValueError, r"not compatible with"):
1108      var = resource_variable_ops.ResourceVariable(
1109          initial_value=np.zeros(shape=[3]),
1110          shape=[4])
1111      self.evaluate(variables.global_variables_initializer())
1112      self.evaluate(var.read_value())
1113
1114  @test_util.run_deprecated_v1
1115  def testDtypeAfterFromProto(self):
1116    v = resource_variable_ops.ResourceVariable(2.0)
1117    w = resource_variable_ops.ResourceVariable.from_proto(v.to_proto())
1118    self.assertIsInstance(w.dtype, dtypes.DType)
1119    self.assertEqual(v.dtype, w.dtype)
1120
1121  # TODO(alive): get caching to work in eager mode.
1122  @test_util.run_deprecated_v1
1123  def testCachingDevice(self):
1124    with ops.device("/job:server/task:1"):
1125      v = resource_variable_ops.ResourceVariable(
1126          2.0, caching_device="/job:localhost")
1127      self.assertEqual("/job:localhost", v.value().device)
1128      with self.assertRaises(ValueError):
1129        _ = v.value().op.get_attr("_class")
1130
1131    with ops.colocate_with(v.op):
1132      w = resource_variable_ops.ResourceVariable(
1133          2.0, caching_device="/job:localhost")
1134      self.assertEqual("/job:localhost", w.value().device)
1135      with self.assertRaises(ValueError):
1136        _ = w.value().op.get_attr("_class")
1137
1138  @test_util.run_deprecated_v1
1139  def testSharedName(self):
1140    with self.cached_session():
1141      v = resource_variable_ops.ResourceVariable(300.0, name="var4")
1142      self.evaluate(variables.global_variables_initializer())
1143
1144      w = _eager_safe_var_handle_op(
1145          dtype=v.dtype.base_dtype,
1146          shape=v.get_shape(),
1147          shared_name="var4",
1148          # Needed in Eager since we get a unique container name by default.
1149          container=ops.get_default_graph()._container)
1150      w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
1151      self.assertEqual(300.0, self.evaluate(w_read))
1152
1153      x = _eager_safe_var_handle_op(
1154          dtype=v.dtype.base_dtype,
1155          shape=v.get_shape(),
1156          shared_name="var5",
1157          container=ops.get_default_graph()._container)
1158      with self.assertRaisesOpError(
1159          "(Resource .*/var5/.* does not exist|uninitialized)"):
1160        resource_variable_ops.read_variable_op(x, v.dtype.base_dtype).eval()
1161
1162  @test_util.run_deprecated_v1
1163  def testSharedNameWithNamescope(self):
1164    with self.cached_session():
1165      with ops.name_scope("foo"):
1166        v = resource_variable_ops.ResourceVariable(300.0, name="var6")
1167        self.assertEqual("foo/var6", v._shared_name)  # pylint: disable=protected-access
1168        self.assertEqual("foo/var6:0", v.name)
1169        self.evaluate(variables.global_variables_initializer())
1170
1171      w = _eager_safe_var_handle_op(
1172          dtype=v.dtype.base_dtype,
1173          shape=v.get_shape(),
1174          shared_name="foo/var6",
1175          # Needed in Eager since we get a unique container name by default.
1176          container=ops.get_default_graph()._container)
1177      w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
1178      self.assertEqual(300.0, self.evaluate(w_read))
1179
1180  @test_util.run_in_graph_and_eager_modes
1181  def testShape(self):
1182    v = resource_variable_ops.ResourceVariable(
1183        name="var4", initial_value=array_ops.ones(shape=[10, 20, 35]))
1184    self.assertEqual("(10, 20, 35)", str(v.shape))
1185    self.assertEqual("(10, 20, 35)", str(v.get_shape()))
1186    self.assertEqual("(10, 20, 35)", str(v.value().shape))
1187    self.assertEqual("(3, 20, 35)", str(v.sparse_read([0, 1, 2]).shape))
1188    if not context.executing_eagerly():
1189      self.assertEqual(
1190          "<unknown>",
1191          str(v.sparse_read(array_ops.placeholder(dtypes.int32)).shape))
1192
1193  @test_util.run_deprecated_v1
1194  def testSetInitialValue(self):
1195    with self.cached_session():
1196      # Initialize variable with a value different from the initial value passed
1197      # in the constructor.
1198      v = resource_variable_ops.ResourceVariable(2.0)
1199      v.initializer.run(feed_dict={v.initial_value: 3.0})
1200      self.assertEqual(3.0, v.value().eval())
1201
1202  @test_util.run_v1_only("b/120545219")
1203  def testControlFlowInitialization(self):
1204    """Expects an error if an initializer is in a control-flow scope."""
1205
1206    def cond(i, _):
1207      return i < 10
1208
1209    def body(i, _):
1210      zero = array_ops.zeros([], dtype=dtypes.int32)
1211      v = resource_variable_ops.ResourceVariable(initial_value=zero)
1212      return (i + 1, v.read_value())
1213
1214    with self.assertRaisesRegex(ValueError, "initial_value"):
1215      control_flow_ops.while_loop(cond, body, [0, 0])
1216
1217  def testVariableEager(self):
1218    with context.eager_mode():
1219      init = array_ops.ones(shape=[10, 20, 35], dtype=dtypes.int32)
1220      constraint = lambda x: x
1221      with ops.name_scope("foo", skip_on_eager=False):
1222        v = resource_variable_ops.ResourceVariable(
1223            name="var7",
1224            initial_value=init,
1225            caching_device="cpu:0",
1226            constraint=constraint)
1227      # Test properties
1228      self.assertEqual(dtypes.int32, v.dtype)
1229      self.assertEqual("foo/var7:0", v.name)
1230      self.assertAllEqual([10, 20, 35], v.shape.as_list())
1231      self.assertIsInstance(v.handle, ops.EagerTensor)
1232      self.assertEqual(constraint, v.constraint)
1233      self.assertAllEqual(init.numpy(), v.read_value().numpy())
1234      self.assertAllEqual(init.numpy(), v.value().numpy())
1235
1236      # Callable init.
1237      callable_init = lambda: init * 2
1238      v2 = resource_variable_ops.ResourceVariable(
1239          initial_value=callable_init, name="var7")
1240      self.assertEqual("var7:0", v2.name)
1241      self.assertAllEqual(2 * init.numpy(), v2.read_value().numpy())
1242
1243      # Test assign_add.
1244      new_v2_val = v2.assign_add(v.read_value())
1245      self.assertAllEqual(v.read_value().numpy() * 3, new_v2_val.numpy())
1246
1247      # Test assign_sub.
1248      new_v2_val = v2.assign_sub(v.read_value())
1249      self.assertAllEqual(v.read_value().numpy() * 2, new_v2_val.numpy())
1250
1251      # Test assign.
1252      v2.assign(v.read_value())
1253      self.assertAllEqual(v.read_value().numpy(), v2.read_value().numpy())
1254
1255      # Test load
1256      v2.load(2 * v.read_value())
1257      self.assertAllEqual(2 * v.read_value().numpy(), v2.read_value().numpy())
1258
1259      # Test convert_to_tensor
1260      t = ops.convert_to_tensor(v)
1261      self.assertAllEqual(t.numpy(), v.read_value().numpy())
1262
1263      # Test operations
1264      self.assertAllEqual((v * 2).numpy(), (v + v).numpy())
1265
1266  def testNumpyDotArray(self):
1267    with context.eager_mode():
1268      # Scalars use a separate code path.
1269      v1 = resource_variable_ops.ResourceVariable(initial_value=lambda: 1,
1270                                                  name="v1")
1271      self.assertEqual(1, np.array(v1))
1272
1273      v2 = resource_variable_ops.ResourceVariable(initial_value=lambda: [1, 2],
1274                                                  name="v2")
1275      self.assertAllEqual(v2.read_value().numpy(), np.array(v2))
1276      self.assertAllEqual([1, 2], np.array(v2))
1277
1278  def testContainerEager(self):
1279    with context.eager_mode():
1280      v1 = resource_variable_ops.ResourceVariable(initial_value=lambda: 1,
1281                                                  name="same")
1282      with ops.container("different"):
1283        v2 = resource_variable_ops.ResourceVariable(initial_value=lambda: 0,
1284                                                    name="same")
1285      v2.assign(2)
1286      self.assertEqual(1, v1.read_value().numpy())
1287      self.assertEqual(2, v2.read_value().numpy())
1288
1289  def testDestruction(self):
1290    with context.eager_mode():
1291      var = resource_variable_ops.ResourceVariable(initial_value=1.0,
1292                                                   name="var8")
1293      var_handle = test_ops.make_weak_resource_handle(var._handle)
1294      del var
1295      with self.assertRaisesRegex(errors.NotFoundError,
1296                                  r"Resource .* does not exist."):
1297        resource_variable_ops.destroy_resource_op(var_handle,
1298                                                  ignore_lookup_error=False)
1299
1300  def testScatterUpdate(self):
1301    with context.eager_mode():
1302      v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="update")
1303      state_ops.scatter_update(v, [1], [3.0])
1304      self.assertAllEqual([1.0, 3.0], v.numpy())
1305
1306  def testScatterAddStateOps(self):
1307    with context.eager_mode():
1308      v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="add")
1309      state_ops.scatter_add(v, [1], [3])
1310      self.assertAllEqual([1.0, 5.0], v.numpy())
1311
1312  def testScatterSubStateOps(self):
1313    with context.eager_mode():
1314      v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="sub")
1315      state_ops.scatter_sub(v, [1], [3])
1316      self.assertAllEqual([1.0, -1.0], v.numpy())
1317
1318  def testScatterUpdateVariant(self):
1319    with context.eager_mode():
1320      v = resource_variable_ops.ResourceVariable([
1321          list_ops.empty_tensor_list(
1322              element_dtype=dtypes.float32, element_shape=[])
1323      ])
1324      v.scatter_update(
1325          indexed_slices.IndexedSlices(
1326              list_ops.tensor_list_from_tensor([1., 2.], element_shape=[]), 0))
1327      self.assertAllEqual(
1328          list_ops.tensor_list_get_item(v[0], 0, element_dtype=dtypes.float32),
1329          1.)
1330
1331  def testGroupDoesntForceRead(self):
1332    with ops.Graph().as_default():
1333      v = resource_variable_ops.ResourceVariable(1.0)
1334      assign = v.assign_add(1.0)
1335      g = control_flow_ops.group([assign])
1336      self.assertEqual(g.control_inputs[0].type, "AssignAddVariableOp")
1337
1338  def testScatterNdAddStateOps(self):
1339    with context.eager_mode():
1340      v = resource_variable_ops.ResourceVariable(
1341          [1, 2, 3, 4, 5, 6, 7, 8], dtype=dtypes.float32, name="add")
1342      indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
1343      updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32)
1344      expected = np.array([1, 13, 3, 14, 14, 6, 7, 20])
1345      state_ops.scatter_nd_add(v, indices, updates)
1346      self.assertAllClose(expected, v.numpy())
1347
1348  @test_util.run_in_graph_and_eager_modes
1349  def testUnreadVariableInsideFunction(self):
1350    v = resource_variable_ops.ResourceVariable(1.0)
1351
1352    @def_function.function
1353    def assign():
1354      v.assign(1.0)
1355
1356    graph = assign.get_concrete_function().graph
1357    self.assertTrue(all(x.type != "ReadVariableOp"
1358                        for x in graph.get_operations()))
1359
1360  def testScatterNdSubStateOps(self):
1361    with context.eager_mode():
1362      v = resource_variable_ops.ResourceVariable(
1363          [1, 2, 3, 4, 5, 6, 7, 8], dtype=dtypes.float32, name="sub")
1364      indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
1365      updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32)
1366      expected = np.array([1, -9, 3, -6, -4, 6, 7, -4])
1367      state_ops.scatter_nd_sub(v, indices, updates)
1368      self.assertAllClose(expected, v.numpy())
1369
1370  def testScatterUpdateCast(self):
1371    with context.eager_mode():
1372      v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="update")
1373      state_ops.scatter_update(v, [1], [3])
1374      self.assertAllEqual([1.0, 3.0], v.numpy())
1375
1376  @test_util.run_in_graph_and_eager_modes
1377  def testScatterUpdateInvalidArgs(self):
1378    v = resource_variable_ops.ResourceVariable([0, 1, 2, 3], name="update")
1379    # The exact error and message differ between graph construction (where the
1380    # error is realized during shape inference at graph construction time),
1381    # eager execution (where the error is realized during kernel execution),
1382    # and XLA auto-clustering execution (where the error is realized in the xla
1383    # op kernel) which is triggered when running in eager op as function mode.
1384    with self.assertRaisesRegex(Exception, r"shape.*2.*3|RET_CHECK failure"):
1385      state_ops.scatter_update(v, [0, 1], [0, 1, 2])
1386
1387  @test_util.disable_xla("b/208334252")  # XLA doesn't have a deterministic impl
1388  def testScatterAddDeterministic(self):
1389    with context.eager_mode(), test_util.deterministic_ops():
1390      # Normally a nondeterministic codepath occurs when the variable has at
1391      # least 1024 elements. Test that op determinism ensures the op is
1392      # deterministc.
1393      v = resource_variable_ops.ResourceVariable(array_ops.zeros([1024]))
1394      delta = ops.IndexedSlices(
1395          values=np.random.normal(size=(1_000_000,)),
1396          indices=array_ops.zeros((1_000_000,), dtype=np.int32),
1397          dense_shape=(1024,))
1398      v.scatter_add(delta)
1399      for _ in range(5):
1400        v2 = resource_variable_ops.ResourceVariable(array_ops.zeros([1024]))
1401        v2.scatter_add(delta)
1402        self.assertAllEqual(v, v2)
1403
1404  @test_util.run_in_graph_and_eager_modes
1405  def testAssignIncompatibleShape(self):
1406    v = resource_variable_ops.ResourceVariable([0, 1, 2, 3])
1407    self.evaluate(v.initializer)
1408    pattern = re.compile("shapes must be equal", re.IGNORECASE)
1409    with self.assertRaisesRegex(Exception, pattern):
1410      self.evaluate(v.assign_add(1))
1411
1412  @test_util.run_in_graph_and_eager_modes
1413  @test_util.run_v1_only("b/120545219")
1414  def testCopyToGraphUninitialized(self):
1415    v = resource_variable_ops.ResourceVariable([0, 1, 2, 3])
1416    copy_to_graph = ops.Graph()
1417    with copy_to_graph.as_default():  # Intentionally testing v1 behavior
1418      copied = resource_variable_ops.copy_to_graph_uninitialized(v)
1419      self.assertEqual(v.name, copied.name)
1420      self.assertIsNone(copied.initializer)
1421
1422  def create_variant_shape_and_type_data(self):
1423    variant_shape_and_type_data = (
1424        cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData())
1425    variant_shape_and_type_data.is_set = True
1426    stored_shape = tensor_shape.TensorShape([None, 4]).as_proto()
1427    stored_dtype = dtypes.float32.as_datatype_enum
1428    # NOTE(ebrevdo): shape_and_type lacks append() in some versions of protobuf.
1429    variant_shape_and_type_data.shape_and_type.extend([
1430        cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
1431            shape=stored_shape,
1432            dtype=stored_dtype,
1433            type=full_type_pb2.FullTypeDef())
1434    ])
1435    return variant_shape_and_type_data
1436
1437  @def_function.function
1438  def create_constant_variant(self, value):
1439    value = constant_op.constant(
1440        tensor_pb2.TensorProto(
1441            dtype=dtypes.variant.as_datatype_enum,
1442            tensor_shape=tensor_shape.TensorShape([]).as_proto(),
1443            variant_val=[
1444                tensor_pb2.VariantTensorDataProto(
1445                    # Match registration in variant_op_registry.cc
1446                    type_name=b"int",
1447                    metadata=np.array(value, dtype=np.int32).tobytes())
1448            ]))
1449    return value
1450
1451  # TODO(ebrevdo): Add run_in_graph_and_eager_modes once we can create
1452  # EagerTensor constants with TensorProto inputs.
1453  @test_util.disable_tfrt("Does not support tf.Const in lowering.")
1454  @test_util.run_in_graph_and_eager_modes()
1455  def testVariantInitializer(self):
1456    variant_shape_and_type_data = self.create_variant_shape_and_type_data()
1457    value = self.create_constant_variant(3)
1458    initializer = array_ops.fill([3], value)
1459    resource_variable_ops._set_handle_shapes_and_types(  # pylint: disable=protected-access
1460        initializer, variant_shape_and_type_data,
1461        graph_mode=not context.executing_eagerly())
1462    v = resource_variable_ops.ResourceVariable(initializer)
1463    read = array_ops.identity(v)
1464    read_variant_shape_and_type = (
1465        resource_variable_ops.get_eager_safe_handle_data(read))
1466    self.assertEqual(
1467        read_variant_shape_and_type, variant_shape_and_type_data)
1468    gather = v.sparse_read([0])
1469    gather_variant_shape_and_type = (
1470        resource_variable_ops.get_eager_safe_handle_data(gather))
1471    self.assertEqual(
1472        gather_variant_shape_and_type, variant_shape_and_type_data)
1473    # Make sure initializer runs.
1474    if not context.executing_eagerly():
1475      self.evaluate(v.initializer)
1476      self.evaluate(read.op)
1477      self.evaluate(gather.op)
1478
1479  @parameterized.parameters([
1480      # batch_dims=0 (equivalent to tf.gather)
1481      dict(  # 2D indices
1482          batch_dims=0,
1483          params=[6, 7, 8, 9],
1484          indices=[[2, 1], [0, 3]],
1485          expected=[[8, 7], [6, 9]]),
1486      dict(  # 3D indices
1487          batch_dims=0,
1488          params=[6, 7, 8, 9],
1489          indices=[[[3, 1], [2, 0]], [[0, 3], [2, 2]]],
1490          expected=[[[9, 7], [8, 6]], [[6, 9], [8, 8]]]),
1491      dict(  # 4D indices
1492          batch_dims=0,
1493          params=[8, 9],
1494          indices=[[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
1495                   [[[1, 1], [0, 0]], [[0, 1], [1, 0]]]],
1496          expected=[[[[8, 9], [9, 8]], [[8, 8], [9, 9]]],
1497                    [[[9, 9], [8, 8]], [[8, 9], [9, 8]]]]),
1498
1499      # batch_dims=indices.shape.ndims - 1 (equivalent to
1500      # tf.compat.v1.batch_gather)
1501      dict(  # 2D indices (1 batch dim)
1502          batch_dims=1,
1503          params=[[10, 11, 12, 13], [20, 21, 22, 23]],
1504          indices=[[2, 1], [0, 3]],
1505          expected=[[12, 11], [20, 23]]),
1506      dict(  # 3D indices (2 batch dims)
1507          batch_dims=2,
1508          params=[[[100, 101], [110, 111]], [[200, 201], [210, 211]]],
1509          indices=[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
1510          expected=[[[100, 101], [111, 110]], [[200, 200], [211, 211]]]),
1511      dict(  # 2D indices (1 batch dim)
1512          batch_dims=1,
1513          params=[[10, 11, 12, 13], [20, 21, 22, 23]],
1514          indices=[[2, 1], [0, 3]],
1515          expected=[[12, 11], [20, 23]]),
1516      dict(  # 3D indices (2 batch dims)
1517          batch_dims=2,
1518          params=[[[100, 101], [110, 111]], [[200, 201], [210, 211]]],
1519          indices=[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
1520          expected=[[[100, 101], [111, 110]], [[200, 200], [211, 211]]]),
1521
1522      # 0 < batch_dims < indices.shape.ndims - 1
1523      dict(  # 3D indices (1 batch dim)
1524          batch_dims=1,
1525          params=[[10, 11, 12, 13], [20, 21, 22, 23]],
1526          indices=[[[3, 1], [2, 0]], [[0, 3], [2, 2]]],
1527          expected=[[[13, 11], [12, 10]], [[20, 23], [22, 22]]]),
1528      dict(  # 4D indices (1 batch dim)
1529          batch_dims=1,
1530          params=[[6, 7], [8, 9]],
1531          indices=[[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
1532                   [[[1, 1], [0, 0]], [[0, 1], [1, 0]]]],
1533          expected=[[[[6, 7], [7, 6]], [[6, 6], [7, 7]]],
1534                    [[[9, 9], [8, 8]], [[8, 9], [9, 8]]]]),
1535      dict(  # 4D indices (2 batch dims)
1536          batch_dims=2,
1537          params=[[[2, 3], [4, 5]], [[6, 7], [8, 9]]],
1538          indices=[[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
1539                   [[[1, 1], [0, 0]], [[0, 1], [1, 0]]]],
1540          expected=[[[[2, 3], [3, 2]], [[4, 4], [5, 5]]],
1541                    [[[7, 7], [6, 6]], [[8, 9], [9, 8]]]]),
1542  ])
1543  @test_util.run_in_graph_and_eager_modes
1544  def testGatherWithBatchDims(self, params, indices, batch_dims, expected):
1545    var = resource_variable_ops.ResourceVariable(params, name="var0")
1546    with ops.control_dependencies([var.initializer]):
1547      result = resource_variable_ops.resource_gather(
1548          var.handle, indices, dtype=var.dtype, batch_dims=batch_dims)
1549    self.assertAllEqual(expected, result)
1550
1551  @parameterized.parameters([
1552      dict(
1553          params_shape=[2, 3, 4, 5, 6, 7],
1554          indices_shape=[2, 3, 8, 9, 10],
1555          batch_dims=0,
1556          output_shape=[2, 3, 8, 9, 10, 3, 4, 5, 6, 7]
1557          # = indices.shape + params.shape[1:]
1558      ),
1559      dict(
1560          params_shape=[2, 3, 4, 5, 6, 7],
1561          indices_shape=[2, 3, 8, 9, 10],
1562          batch_dims=1,
1563          output_shape=[2, 3, 8, 9, 10, 4, 5, 6, 7]
1564          # = params.shape[:1] + indices.shape[1:] + params.shape[2:]
1565      ),
1566      dict(
1567          params_shape=[2, 3, 4, 5, 6, 7],
1568          indices_shape=[2, 3, 8, 9, 10],
1569          batch_dims=2,
1570          output_shape=[2, 3, 8, 9, 10, 5, 6, 7]
1571          # = params.shape[:2] + indices.shape[2:] + params.shape[3:]
1572      ),
1573      dict(
1574          params_shape=[2, 3, 4, 5, 6, 7],
1575          indices_shape=[2, 3, 4, 9, 10],
1576          batch_dims=3,
1577          output_shape=[2, 3, 4, 9, 10, 6, 7]
1578          # = params.shape[:3] + indices.shape[3:] + params.shape[4:]
1579      ),
1580      dict(
1581          params_shape=[2, 3, 4, 5, 6, 7],
1582          indices_shape=[2, 3, 4, 5, 10],
1583          batch_dims=4,
1584          output_shape=[2, 3, 4, 5, 10, 7]
1585          # = params.shape[:4] + indices.shape[4:] + params.shape[5:]
1586      ),
1587  ])
1588  @test_util.run_in_graph_and_eager_modes
1589  def testGatherWithBatchDimsMatchesTensor(self, params_shape, indices_shape,
1590                                           batch_dims, output_shape):
1591    """Checks that gather with batch_dims returns the correct shape."""
1592    # Generate a `params` tensor with the indicated shape.
1593    params_size = np.prod(params_shape)
1594    params = np.reshape(np.arange(params_size, dtype=np.int32), params_shape)
1595
1596    # Generate an `indices` tensor with the indicated shape, where each index
1597    # is within the appropriate range.
1598    indices_size = np.prod(indices_shape)
1599    indices = np.reshape(np.arange(indices_size, dtype=np.int32), indices_shape)
1600    indices = indices % params_shape[batch_dims]
1601
1602    var = resource_variable_ops.ResourceVariable(params, name="var0")
1603    with ops.control_dependencies([var.initializer]):
1604      expected = array_ops.gather(
1605          var.read_value(), indices, batch_dims=batch_dims)
1606      result = resource_variable_ops.resource_gather(
1607          var.handle, indices, dtype=var.dtype, batch_dims=batch_dims)
1608
1609    self.assertAllEqual(output_shape, result.shape.as_list())
1610    self.assertAllEqual(expected, result)
1611
1612  @parameterized.parameters([
1613      dict(dtype=dtypes.bool),
1614      dict(dtype=dtypes.int64),
1615      dict(dtype=dtypes.half),
1616      dict(dtype=dtypes.float32),
1617      dict(dtype=dtypes.double),
1618  ])
1619  @test_util.run_gpu_only
1620  @test_util.run_in_graph_and_eager_modes
1621  def testGatherWithDTypes(self, dtype):
1622    if dtype == dtypes.bool:
1623      params = constant_op.constant([False, True, False, True])
1624      expected = constant_op.constant([[False, True], [False, True]])
1625    else:
1626      params = constant_op.constant([6, 7, 8, 9], dtype=dtype)
1627      expected = constant_op.constant([[8, 7], [6, 9]], dtype=dtype)
1628    indices = constant_op.constant([[2, 1], [0, 3]])
1629    var = resource_variable_ops.ResourceVariable(params, name="var0")
1630    with ops.control_dependencies([var.initializer]):
1631      result = resource_variable_ops.resource_gather(
1632          var.handle, indices, dtype=dtype)
1633    self.assertAllEqual(expected, result)
1634
1635  @test_util.run_v2_only
1636  def testUninitializedVariableMemoryUsage(self):
1637    if test_util.is_gpu_available():
1638      # TODO(allenl): Investigate possible GPU-specific memory leaks
1639      self.skipTest("Disabled when a GPU is available")
1640    # TODO(kkb): Python memory checker complains continuous `weakref`
1641    # allocations, investigate.
1642    if memory_checker.CppMemoryChecker is None:
1643      self.skipTest("Requires the C++ memory checker")
1644
1645    def _create_and_delete_variable():
1646      resource_variable_ops.UninitializedVariable(
1647          shape=[100, 100],
1648          dtype=dtypes.float32)
1649
1650    _create_and_delete_variable()
1651    checker = memory_checker.CppMemoryChecker(
1652        "ResourceVariableOps.testUninitializedVariableMemoryUsage")
1653    for _ in range(2):
1654      _create_and_delete_variable()
1655      checker.record_snapshot()
1656    checker.stop()
1657    checker.report()
1658    checker.assert_no_leak_if_all_possibly_except_one()
1659
1660  @test_util.run_v2_only
1661  def testIterateVariable(self):
1662    v = variables.Variable([1., 2.])
1663    self.assertAllClose([1., 2.], list(iter(v)))
1664
1665
1666if __name__ == "__main__":
1667  test.main()
1668