1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.ops.resource_variable_ops.""" 16import copy 17import gc 18import os 19import pickle 20import re 21 22from absl.testing import parameterized 23import numpy as np 24 25from tensorflow.core.framework import full_type_pb2 26from tensorflow.core.framework import tensor_pb2 27from tensorflow.python.compat import compat as forward_compat 28from tensorflow.python.eager import backprop 29from tensorflow.python.eager import context 30from tensorflow.python.eager import def_function 31from tensorflow.python.framework import constant_op 32from tensorflow.python.framework import cpp_shape_inference_pb2 33from tensorflow.python.framework import dtypes 34from tensorflow.python.framework import errors 35from tensorflow.python.framework import indexed_slices 36from tensorflow.python.framework import memory_checker 37from tensorflow.python.framework import ops 38from tensorflow.python.framework import tensor_shape 39from tensorflow.python.framework import tensor_util 40from tensorflow.python.framework import test_ops 41from tensorflow.python.framework import test_util 42from tensorflow.python.ops import array_ops 43from tensorflow.python.ops import control_flow_ops 44from tensorflow.python.ops import custom_gradient 45from tensorflow.python.ops import gradients_impl 46from tensorflow.python.ops import handle_data_util 47from tensorflow.python.ops import init_ops 48from tensorflow.python.ops import list_ops 49from tensorflow.python.ops import math_ops 50from tensorflow.python.ops import resource_variable_ops 51from tensorflow.python.ops import state_ops 52from tensorflow.python.ops import variable_scope 53from tensorflow.python.ops import variables 54from tensorflow.python.platform import test 55from tensorflow.python.training import momentum 56from tensorflow.python.training import saver 57from tensorflow.python.training import training_util 58from tensorflow.python.util import compat 59 60 61def _eager_safe_var_handle_op(*args, **kwargs): 62 # When running in eager mode the `shared_name` should be set to the 63 # `anonymous_name` to avoid spurious sharing issues. The runtime generates a 64 # unique name on our behalf when the reserved `anonymous_name` is used as the 65 # `shared_name`. 66 if context.executing_eagerly() and "shared_name" not in kwargs: 67 kwargs["shared_name"] = context.anonymous_name() 68 return resource_variable_ops.var_handle_op(*args, **kwargs) 69 70 71@test_util.with_eager_op_as_function 72@test_util.with_control_flow_v2 73class ResourceVariableOpsTest(test_util.TensorFlowTestCase, 74 parameterized.TestCase): 75 76 def tearDown(self): 77 gc.collect() 78 # This will only contain uncollectable garbage, i.e. reference cycles 79 # involving objects with __del__ defined. 80 self.assertEmpty(gc.garbage) 81 super(ResourceVariableOpsTest, self).tearDown() 82 83 @test_util.run_deprecated_v1 84 def testHandleDtypeShapeMatch(self): 85 with self.cached_session(): 86 handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[]) 87 with self.assertRaises(ValueError): 88 resource_variable_ops.assign_variable_op( 89 handle, constant_op.constant(0.0, dtype=dtypes.float32)).run() 90 with self.assertRaises(ValueError): 91 resource_variable_ops.assign_variable_op(handle, 92 constant_op.constant( 93 [0], 94 dtype=dtypes.int32)).run() 95 resource_variable_ops.assign_variable_op(handle, 96 constant_op.constant( 97 0, 98 dtype=dtypes.int32)).run() 99 100 @test_util.run_gpu_only 101 def testGPUInt64(self): 102 with context.eager_mode(), context.device("gpu:0"): 103 v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int64) 104 self.assertAllEqual(1, v.numpy()) 105 106 @test_util.run_gpu_only 107 def testGPUBfloat16(self): 108 with context.eager_mode(), ops.device("gpu:0"): 109 v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.bfloat16) 110 self.assertEqual("/job:localhost/replica:0/task:0/device:GPU:0", 111 v.device) 112 self.assertAllEqual(1, v.numpy()) 113 114 def testEagerNameNotIdentity(self): 115 with context.eager_mode(): 116 v0 = resource_variable_ops.ResourceVariable(1.0, name="a") 117 v1 = resource_variable_ops.ResourceVariable(2.0, name="a") 118 self.assertAllEqual(v0.numpy(), 1.0) 119 self.assertAllEqual(v1.numpy(), 2.0) 120 121 def testEagerNameNotNeeded(self): 122 with context.eager_mode(): 123 v0 = resource_variable_ops.ResourceVariable(1.0) 124 self.assertAllEqual(v0.numpy(), 1.0) 125 126 def testReadVariableDtypeMismatchEager(self): 127 with context.eager_mode(): 128 handle = _eager_safe_var_handle_op( 129 dtype=dtypes.int32, shape=[1], name="foo") 130 resource_variable_ops.assign_variable_op(handle, 1) 131 # The error message varies depending on whether it is being raised 132 # by the kernel or shape inference. The shape inference code path can 133 # be reached when running in eager op as function mode where each op 134 # is wrapped in a tf.function. 135 with self.assertRaisesRegex( 136 errors.InvalidArgumentError, 137 r"Trying to read variable with wrong dtype. " 138 r"Expected (float|int32) got (int32|float)"): 139 _ = resource_variable_ops.read_variable_op(handle, dtype=dtypes.float32) 140 141 def testEagerInitializedValue(self): 142 with context.eager_mode(): 143 variable = resource_variable_ops.ResourceVariable(1.0, name="eager-init") 144 self.assertAllEqual(variable.numpy(), 1.0) 145 self.assertAllEqual(variable.initialized_value().numpy(), 1.0) 146 147 def testInitializeVariableUsingInitializedValue(self): 148 var1 = resource_variable_ops.ResourceVariable(1.0, name="var1") 149 var2 = resource_variable_ops.ResourceVariable(var1.initialized_value(), 150 name="var2") 151 self.assertAllEqual(var2.initialized_value(), 1.0) 152 153 def testEagerBool(self): 154 with context.eager_mode(): 155 v = resource_variable_ops.ResourceVariable(False, name="bool_test") 156 self.assertAllEqual(bool(v), False) 157 158 def testEagerDeepCopy(self): 159 with context.eager_mode(): 160 init_value = np.ones((4, 4, 4)) 161 variable = resource_variable_ops.ResourceVariable( 162 init_value, 163 name="init", 164 synchronization=variables.VariableSynchronization.ON_READ, 165 aggregation=variables.VariableAggregation.SUM) 166 167 copied_variable = copy.deepcopy(variable) 168 self.assertEqual(variable.name, copied_variable.name) 169 self.assertEqual(variable.shape, copied_variable.shape) 170 self.assertEqual(variable.device, copied_variable.device) 171 self.assertEqual(variable.synchronization, 172 copied_variable.synchronization) 173 self.assertEqual(variable.aggregation, copied_variable.aggregation) 174 175 # The copied variable should have the same value as the original. 176 self.assertAllEqual(variable.numpy(), copied_variable.numpy()) 177 178 # Updates to the copy should not be reflected in the original. 179 copied_variable.assign(4 * np.ones((4, 4, 4))) 180 self.assertNotAllEqual(variable.numpy(), copied_variable.numpy()) 181 182 @test_util.run_deprecated_v1 183 def testGraphDeepCopy(self): 184 with self.cached_session(): 185 init_value = np.ones((4, 4, 4)) 186 variable = resource_variable_ops.ResourceVariable(init_value, 187 name="init") 188 with self.assertRaises(NotImplementedError): 189 copy.deepcopy(variable) 190 191 @test_util.run_in_graph_and_eager_modes 192 def testStridedSliceAssign(self): 193 v = resource_variable_ops.ResourceVariable([1.0, 2.0]) 194 self.evaluate(variables.global_variables_initializer()) 195 self.evaluate(v[0].assign(2.0)) 196 self.assertAllEqual(self.evaluate(v), [2.0, 2.0]) 197 198 @test_util.run_in_graph_and_eager_modes 199 def testVariableShape(self): 200 v = resource_variable_ops.ResourceVariable([1., 1.]) 201 vshape = resource_variable_ops.variable_shape(v.handle) 202 self.assertAllEqual( 203 tensor_util.constant_value(vshape), 204 [2]) 205 if not context.executing_eagerly(): 206 self.assertEqual("Const", vshape.op.type) 207 208 @test_util.run_deprecated_v1 209 def testDifferentAssignGraph(self): 210 with ops.Graph().as_default(): 211 v = resource_variable_ops.ResourceVariable(1.0) 212 ops.reset_default_graph() 213 v.assign(2.0) # Note: this fails if we run convert_to_tensor on not the 214 # variable graph. 215 216 @test_util.run_deprecated_v1 217 def testFetchHandle(self): 218 with self.cached_session(): 219 handle = _eager_safe_var_handle_op( 220 dtype=dtypes.int32, shape=[1], name="foo") 221 self.assertNotEmpty(self.evaluate(handle)) 222 223 @test_util.run_deprecated_v1 224 def testCachedValueReadBeforeWrite(self): 225 with self.cached_session() as sess: 226 v = resource_variable_ops.ResourceVariable(0.0, caching_device="cpu:0") 227 self.evaluate(v.initializer) 228 value, _ = sess.run([v, v.assign_add(1.0)]) 229 self.assertAllEqual(value, 0.0) 230 231 def testAssignVariableDtypeMismatchEager(self): 232 with context.eager_mode(): 233 handle = _eager_safe_var_handle_op( 234 dtype=dtypes.int32, shape=[1], name="foo") 235 resource_variable_ops.assign_variable_op( 236 handle, constant_op.constant([1])) 237 # The error message varies depending on whether it is being raised 238 # by the kernel or shape inference. The shape inference code path can 239 # be reached when running in eager op as function mode where each op 240 # is wrapped in a tf.function. 241 with self.assertRaisesRegex( 242 errors.InvalidArgumentError, r"Trying to .* variable with wrong " 243 r"dtype. Expected int32 got float"): 244 resource_variable_ops.assign_variable_op( 245 handle, constant_op.constant([1.], dtype=dtypes.float32)) 246 247 def testRepr(self): 248 with context.eager_mode(): 249 v = resource_variable_ops.ResourceVariable(1) 250 text = "%r" % v 251 self.assertEqual( 252 "<tf.Variable 'Variable:0' shape=() dtype=int32, numpy=1>", text) 253 254 def testReprUnavailable(self): 255 with context.eager_mode(): 256 v = resource_variable_ops.ResourceVariable(1) 257 258 # Monkey-patch this variable to not have an available value 259 def broken_read(): 260 raise ValueError("This doesn't work") 261 262 v.read_value = broken_read 263 text = "%r" % v 264 self.assertEqual("<tf.Variable 'Variable:0' shape=() dtype=int32," 265 " numpy=<unavailable>>", text) 266 267 def testFormatResourceHandle(self): 268 with context.eager_mode(): 269 handle = _eager_safe_var_handle_op( 270 dtype=dtypes.int32, shape=[1], name="foo") 271 self.assertIn("<ResourceHandle", str(handle)) 272 self.assertIn("<ResourceHandle", repr(handle)) 273 274 @test_util.run_in_graph_and_eager_modes 275 def testDtypeSurvivesIdentity(self): 276 handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[]) 277 id_handle = array_ops.identity(handle) 278 self.evaluate(resource_variable_ops.assign_variable_op( 279 id_handle, constant_op.constant(0, dtype=dtypes.int32))) 280 281 def testUnreadOpName(self): 282 v = resource_variable_ops.ResourceVariable(1.0) 283 self.assertNotEqual(v.name, v.assign_add(1.0).name) 284 285 @test_util.run_in_graph_and_eager_modes 286 def testCreateRead(self): 287 handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[]) 288 self.evaluate(resource_variable_ops.assign_variable_op( 289 handle, constant_op.constant(1, dtype=dtypes.int32))) 290 value = self.evaluate( 291 resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)) 292 self.assertAllEqual(1, value) 293 294 @test_util.run_in_graph_and_eager_modes 295 def testManyAssigns(self): 296 handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[]) 297 create = resource_variable_ops.assign_variable_op( 298 handle, constant_op.constant(1, dtype=dtypes.int32)) 299 with ops.control_dependencies([create]): 300 first_read = resource_variable_ops.read_variable_op( 301 handle, dtype=dtypes.int32) 302 with ops.control_dependencies([first_read]): 303 write = resource_variable_ops.assign_variable_op( 304 handle, constant_op.constant(2, dtype=dtypes.int32)) 305 with ops.control_dependencies([write]): 306 second_read = resource_variable_ops.read_variable_op( 307 handle, dtype=dtypes.int32) 308 f, s = self.evaluate([first_read, second_read]) 309 self.assertEqual(f, 1) 310 self.assertEqual(s, 2) 311 312 @test_util.run_in_graph_and_eager_modes 313 def testAssignAdd(self): 314 handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[]) 315 self.evaluate(resource_variable_ops.assign_variable_op( 316 handle, constant_op.constant(1, dtype=dtypes.int32))) 317 self.evaluate(resource_variable_ops.assign_add_variable_op( 318 handle, constant_op.constant(1, dtype=dtypes.int32))) 319 read = self.evaluate( 320 resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)) 321 self.assertEqual(read, 2) 322 323 @test_util.run_in_graph_and_eager_modes 324 def testScatterAdd(self): 325 handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1]) 326 self.evaluate( 327 resource_variable_ops.assign_variable_op( 328 handle, constant_op.constant([[1]], dtype=dtypes.int32))) 329 self.evaluate( 330 resource_variable_ops.resource_scatter_add( 331 handle, [0], constant_op.constant([[2]], dtype=dtypes.int32))) 332 read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) 333 self.assertEqual(self.evaluate(read), [[3]]) 334 335 @test_util.run_in_graph_and_eager_modes 336 def testGradientGatherNd(self): 337 v = resource_variable_ops.ResourceVariable( 338 np.random.uniform(size=[2, 2]), dtype=dtypes.float32) 339 340 with backprop.GradientTape() as tape: 341 l = array_ops.gather_nd(v, [[1, 1]]) 342 l = math_ops.reduce_sum(l) 343 344 grads = tape.gradient(l, v) 345 self.evaluate(variables.global_variables_initializer()) 346 self.assertAllEqual(self.evaluate(grads), [[0., 0.], [0., 1.]]) 347 348 @test_util.run_deprecated_v1 349 def testDefaultGradientDtype(self): 350 v = resource_variable_ops.ResourceVariable( 351 np.random.uniform(size=[2, 2]), dtype=dtypes.float64) 352 353 c = constant_op.constant(1.) 354 identity = array_ops.identity_n([c, v.handle]) 355 # TODO(b/137403775): Remove this. 356 handle_data_util.copy_handle_data(v.handle, identity[1]) 357 358 g = gradients_impl.gradients(identity[0], [c, v.handle]) 359 self.assertEqual(g[1].dtype, dtypes.float64) 360 self.evaluate(variables.global_variables_initializer()) 361 self.assertAllEqual(g[1], [[0., 0.], [0., 0.]]) 362 363 @test_util.run_deprecated_v1 364 def testUnconnectedGradientZeros(self): 365 b = resource_variable_ops.ResourceVariable(initial_value=[[3., 4.]]) 366 c = constant_op.constant(0.) 367 g = gradients_impl.gradients(c, [b], unconnected_gradients="zero")[0] 368 self.assertAllEqual(g.shape.as_list(), [1, 2]) 369 370 @test_util.run_deprecated_v1 371 def testGradientCondInWhileLoop(self): 372 v = resource_variable_ops.ResourceVariable(initial_value=1.0) 373 def cond(i, unused_x): 374 return i < 1 375 376 def body(i, x): 377 def true(): 378 return x + v 379 def false(): 380 return 2.0 * v 381 return i + 1, control_flow_ops.cond(i > 0, true, false) 382 383 _, x = control_flow_ops.while_loop(cond, body, [0, 0.0]) 384 # Computing gradients does not produce an exception: 385 g = gradients_impl.gradients(x, v) 386 self.evaluate(variables.global_variables_initializer()) 387 # Only the false branch is taken so the gradient is 2. 388 self.assertAllEqual(g[0], 2.0) 389 390 @test_util.run_in_graph_and_eager_modes 391 def testGradientGatherNdIndexedSlices(self): 392 v = resource_variable_ops.ResourceVariable( 393 np.random.uniform(size=[2, 2]), dtype=dtypes.float32) 394 395 with backprop.GradientTape() as tape: 396 l = array_ops.gather_nd(v, [[1], [1]]) 397 l = math_ops.reduce_sum(l) 398 399 grads = tape.gradient(l, v) 400 self.evaluate(variables.global_variables_initializer()) 401 self.assertAllEqual(self.evaluate(grads.values), [[1., 1.], [1., 1.]]) 402 403 @test_util.run_in_graph_and_eager_modes 404 def testScatterSub(self): 405 handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1]) 406 self.evaluate( 407 resource_variable_ops.assign_variable_op( 408 handle, constant_op.constant([[1]], dtype=dtypes.int32))) 409 self.evaluate( 410 resource_variable_ops.resource_scatter_sub( 411 handle, [0], constant_op.constant([[2]], dtype=dtypes.int32))) 412 read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) 413 self.assertEqual(self.evaluate(read), [[-1]]) 414 415 @test_util.run_in_graph_and_eager_modes 416 def testScatterMul(self): 417 handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1]) 418 self.evaluate( 419 resource_variable_ops.assign_variable_op( 420 handle, constant_op.constant([[1]], dtype=dtypes.int32))) 421 self.evaluate( 422 resource_variable_ops.resource_scatter_mul( 423 handle, [0], constant_op.constant([[5]], dtype=dtypes.int32))) 424 read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) 425 self.assertEqual(self.evaluate(read), [[5]]) 426 427 def testEagerPickle(self): 428 with context.eager_mode(): 429 tmp_dir = self.get_temp_dir() 430 fname = os.path.join(tmp_dir, "var.pickle") 431 with open(fname, "wb") as f: 432 v = resource_variable_ops.ResourceVariable( 433 10.0, 434 dtype=dtypes.float16, 435 name="v") 436 pickle.dump(v, f) 437 438 with open(fname, "rb") as f: 439 new_v = pickle.load(f) 440 self.assertEqual(new_v.name, v.name) 441 self.assertEqual(new_v.shape, v.shape) 442 self.assertEqual(new_v.dtype, v.dtype) 443 self.assertEqual(new_v.trainable, v.trainable) 444 self.assertAllEqual(new_v.numpy(), v.numpy()) 445 446 @test_util.run_in_graph_and_eager_modes 447 def testScatterDiv(self): 448 handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1]) 449 self.evaluate( 450 resource_variable_ops.assign_variable_op( 451 handle, constant_op.constant([[6]], dtype=dtypes.int32))) 452 self.evaluate( 453 resource_variable_ops.resource_scatter_div( 454 handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) 455 read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) 456 self.assertEqual(self.evaluate(read), [[2]]) 457 458 def testUseResource(self): 459 v = variables.VariableV1(1.0, use_resource=True) 460 self.assertIsInstance(v, resource_variable_ops.ResourceVariable) 461 462 def testEagerNoUseResource(self): 463 with context.eager_mode(): 464 v = variables.Variable(1.0) 465 self.assertIsInstance(v, resource_variable_ops.ResourceVariable) 466 467 @test_util.run_in_graph_and_eager_modes 468 def testScatterMin(self): 469 with ops.device("cpu:0"): 470 handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1]) 471 self.evaluate( 472 resource_variable_ops.assign_variable_op(handle, 473 constant_op.constant( 474 [[6]], 475 dtype=dtypes.int32))) 476 self.evaluate( 477 resource_variable_ops.resource_scatter_min(handle, [0], 478 constant_op.constant( 479 [[3]], 480 dtype=dtypes.int32))) 481 read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) 482 self.assertEqual(self.evaluate(read), [[3]]) 483 484 def testMetagraph(self): 485 with ops.Graph().as_default(): 486 with variable_scope.variable_scope("foo", use_resource=True): 487 a = variable_scope.get_variable("a", initializer=10.0) 488 489 momentum.MomentumOptimizer( 490 learning_rate=0.001, momentum=0.1).minimize( 491 a, 492 colocate_gradients_with_ops=True, 493 global_step=training_util.get_or_create_global_step()) 494 495 graph = ops.get_default_graph() 496 meta_graph_def = saver.export_meta_graph(graph=graph) 497 498 with ops.Graph().as_default(): 499 saver.import_meta_graph(meta_graph_def, import_scope="") 500 meta_graph_two = saver.export_meta_graph(graph=graph) 501 self.assertEqual(meta_graph_def, meta_graph_two) 502 503 @test_util.run_in_graph_and_eager_modes 504 def testScatterMax(self): 505 handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1]) 506 self.evaluate( 507 resource_variable_ops.assign_variable_op( 508 handle, constant_op.constant([[6]], dtype=dtypes.int32))) 509 self.evaluate( 510 resource_variable_ops.resource_scatter_max( 511 handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) 512 read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) 513 self.assertEqual(self.evaluate(read), [[6]]) 514 515 @test_util.run_in_graph_and_eager_modes 516 def testScatterAddScalar(self): 517 handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1]) 518 self.evaluate( 519 resource_variable_ops.assign_variable_op( 520 handle, constant_op.constant([[1]], dtype=dtypes.int32))) 521 self.evaluate( 522 resource_variable_ops.resource_scatter_add( 523 handle, [0], constant_op.constant(2, dtype=dtypes.int32))) 524 read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) 525 self.assertEqual(self.evaluate(read), [[3]]) 526 527 @test_util.run_in_graph_and_eager_modes 528 def testScatterSubScalar(self): 529 handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1]) 530 self.evaluate( 531 resource_variable_ops.assign_variable_op( 532 handle, constant_op.constant([[1]], dtype=dtypes.int32))) 533 self.evaluate( 534 resource_variable_ops.resource_scatter_sub( 535 handle, [0], constant_op.constant(2, dtype=dtypes.int32))) 536 read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) 537 self.assertEqual(self.evaluate(read), [[-1]]) 538 539 @test_util.run_in_graph_and_eager_modes 540 def testScatterMulScalar(self): 541 handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1]) 542 self.evaluate( 543 resource_variable_ops.assign_variable_op( 544 handle, constant_op.constant([[1]], dtype=dtypes.int32))) 545 self.evaluate( 546 resource_variable_ops.resource_scatter_mul( 547 handle, [0], constant_op.constant(5, dtype=dtypes.int32))) 548 read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) 549 self.assertEqual(self.evaluate(read), [[5]]) 550 551 @test_util.run_in_graph_and_eager_modes 552 def testScatterDivScalar(self): 553 handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1]) 554 self.evaluate( 555 resource_variable_ops.assign_variable_op( 556 handle, constant_op.constant([[6]], dtype=dtypes.int32))) 557 self.evaluate( 558 resource_variable_ops.resource_scatter_div( 559 handle, [0], constant_op.constant(3, dtype=dtypes.int32))) 560 read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) 561 self.assertEqual(self.evaluate(read), [[2]]) 562 563 @test_util.run_in_graph_and_eager_modes 564 def testScatterMinScalar(self): 565 handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1]) 566 self.evaluate( 567 resource_variable_ops.assign_variable_op( 568 handle, constant_op.constant([[6]], dtype=dtypes.int32))) 569 self.evaluate( 570 resource_variable_ops.resource_scatter_min( 571 handle, [0], constant_op.constant(3, dtype=dtypes.int32))) 572 read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) 573 self.assertEqual(self.evaluate(read), [[3]]) 574 575 @test_util.run_in_graph_and_eager_modes 576 def testScatterMaxScalar(self): 577 handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1]) 578 self.evaluate( 579 resource_variable_ops.assign_variable_op( 580 handle, constant_op.constant([[6]], dtype=dtypes.int32))) 581 self.evaluate( 582 resource_variable_ops.resource_scatter_max( 583 handle, [0], constant_op.constant(3, dtype=dtypes.int32))) 584 read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) 585 self.assertEqual(self.evaluate(read), [[6]]) 586 587 @parameterized.parameters(dtypes.float16, dtypes.float32, dtypes.float64) 588 @test_util.run_in_graph_and_eager_modes 589 def testScatterAddVariableMethod(self, dtype): 590 v = resource_variable_ops.ResourceVariable([0.0, 1.5], 591 name="add", 592 dtype=dtype) 593 self.evaluate(variables.global_variables_initializer()) 594 self.evaluate( 595 v.scatter_add( 596 indexed_slices.IndexedSlices( 597 indices=[1], values=constant_op.constant([2.5], dtype=dtype)))) 598 self.assertAllCloseAccordingToType([0.0, 4.0], self.evaluate(v)) 599 600 @parameterized.parameters(dtypes.float16, dtypes.float32, dtypes.float64) 601 @test_util.run_in_graph_and_eager_modes 602 def testScatterSubVariableMethod(self, dtype): 603 v = resource_variable_ops.ResourceVariable([0.0, 2.5], 604 name="sub", 605 dtype=dtype) 606 self.evaluate(variables.global_variables_initializer()) 607 self.evaluate( 608 v.scatter_sub( 609 indexed_slices.IndexedSlices( 610 indices=[1], values=constant_op.constant([1.5], dtype=dtype)))) 611 self.assertAllCloseAccordingToType([0.0, 1.0], self.evaluate(v)) 612 613 @parameterized.parameters(dtypes.float16, dtypes.float32, dtypes.float64) 614 @test_util.run_in_graph_and_eager_modes 615 def testScatterMaxVariableMethod(self, dtype): 616 v = resource_variable_ops.ResourceVariable([0.0, 4.0], 617 name="max1", 618 dtype=dtype) 619 self.evaluate(variables.global_variables_initializer()) 620 self.evaluate( 621 v.scatter_max( 622 indexed_slices.IndexedSlices( 623 indices=[1], values=constant_op.constant([5.0], dtype=dtype)))) 624 self.assertAllCloseAccordingToType([0.0, 5.0], self.evaluate(v)) 625 626 v = resource_variable_ops.ResourceVariable([0.0, 3.5], 627 name="max2", 628 dtype=dtype) 629 self.evaluate(variables.global_variables_initializer()) 630 self.evaluate( 631 v.scatter_max( 632 indexed_slices.IndexedSlices( 633 indices=[1], values=constant_op.constant([2.0], dtype=dtype)))) 634 self.assertAllCloseAccordingToType([0.0, 3.5], self.evaluate(v)) 635 636 @parameterized.parameters(dtypes.float16, dtypes.float32, dtypes.float64) 637 @test_util.run_in_graph_and_eager_modes 638 def testScatterMinVariableMethod(self, dtype): 639 v = resource_variable_ops.ResourceVariable([0.0, 4.0], 640 name="min1", 641 dtype=dtype) 642 self.evaluate(variables.global_variables_initializer()) 643 self.evaluate( 644 v.scatter_min( 645 indexed_slices.IndexedSlices( 646 indices=[1], values=constant_op.constant([5.0], dtype=dtype)))) 647 self.assertAllCloseAccordingToType([0.0, 4.0], self.evaluate(v)) 648 649 v = resource_variable_ops.ResourceVariable([0.0, 3.5], 650 name="min2", 651 dtype=dtype) 652 self.evaluate(variables.global_variables_initializer()) 653 self.evaluate( 654 v.scatter_min( 655 indexed_slices.IndexedSlices( 656 indices=[1], values=constant_op.constant([2.0], dtype=dtype)))) 657 self.assertAllCloseAccordingToType([0.0, 2.0], self.evaluate(v)) 658 659 @parameterized.parameters(dtypes.float16, dtypes.float32, dtypes.float64) 660 @test_util.run_in_graph_and_eager_modes 661 def testScatterMulVariableMethod(self, dtype): 662 v = resource_variable_ops.ResourceVariable([0.0, 4.0], 663 name="mul", 664 dtype=dtype) 665 self.evaluate(variables.global_variables_initializer()) 666 self.evaluate( 667 v.scatter_mul( 668 indexed_slices.IndexedSlices( 669 indices=[1], values=constant_op.constant([3.0], dtype=dtype)))) 670 self.assertAllCloseAccordingToType([0.0, 12.0], self.evaluate(v)) 671 672 @parameterized.parameters(dtypes.float16, dtypes.float32, dtypes.float64) 673 @test_util.run_in_graph_and_eager_modes 674 def testScatterDivVariableMethod(self, dtype): 675 v = resource_variable_ops.ResourceVariable([0.0, 6.0], 676 name="div", 677 dtype=dtype) 678 self.evaluate(variables.global_variables_initializer()) 679 self.evaluate( 680 v.scatter_div( 681 indexed_slices.IndexedSlices( 682 indices=[1], values=constant_op.constant([2.0], dtype=dtype)))) 683 self.assertAllCloseAccordingToType([0.0, 3.0], self.evaluate(v)) 684 685 @parameterized.parameters(dtypes.float16, dtypes.float32, dtypes.float64) 686 @test_util.run_in_graph_and_eager_modes 687 def testScatterUpdateVariableMethod(self, dtype): 688 v = resource_variable_ops.ResourceVariable([0.0, 6.0], 689 name="update", 690 dtype=dtype) 691 self.evaluate(variables.global_variables_initializer()) 692 self.evaluate( 693 v.scatter_update( 694 indexed_slices.IndexedSlices( 695 indices=[1], values=constant_op.constant([3.0], dtype=dtype)))) 696 self.assertAllCloseAccordingToType([0.0, 3.0], self.evaluate(v)) 697 698 @test_util.run_deprecated_v1 699 def testScatterUpdateString(self): 700 handle = _eager_safe_var_handle_op(dtype=dtypes.string, shape=[1, 1]) 701 self.evaluate(resource_variable_ops.assign_variable_op( 702 handle, constant_op.constant([["a"]], dtype=dtypes.string))) 703 self.evaluate(resource_variable_ops.resource_scatter_update( 704 handle, [0], constant_op.constant([["b"]], dtype=dtypes.string))) 705 read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.string) 706 self.assertEqual(compat.as_bytes(self.evaluate(read)[0][0]), 707 compat.as_bytes("b")) 708 709 @test_util.run_deprecated_v1 710 def testScatterUpdateStringScalar(self): 711 handle = _eager_safe_var_handle_op(dtype=dtypes.string, shape=[1, 1]) 712 self.evaluate( 713 resource_variable_ops.assign_variable_op(handle, 714 constant_op.constant( 715 [["a"]], 716 dtype=dtypes.string))) 717 self.evaluate( 718 resource_variable_ops.resource_scatter_update(handle, [0], 719 constant_op.constant( 720 "b", 721 dtype=dtypes.string))) 722 read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.string) 723 self.assertEqual( 724 compat.as_bytes(self.evaluate(read)[0][0]), compat.as_bytes("b")) 725 726 # TODO(alive): get this to work in Eager mode. 727 def testGPU(self): 728 with test_util.use_gpu(): 729 abc = variable_scope.get_variable( 730 "abc", 731 shape=[1], 732 initializer=init_ops.ones_initializer(), 733 use_resource=True) 734 735 self.evaluate(variables.global_variables_initializer()) 736 self.assertEqual( 737 self.evaluate( 738 resource_variable_ops.var_is_initialized_op(abc.handle)), 739 True) 740 741 def testScatterBool(self): 742 with context.eager_mode(): 743 ref = resource_variable_ops.ResourceVariable( 744 [False, True, False], trainable=False) 745 indices = math_ops.range(3) 746 updates = constant_op.constant([True, True, True]) 747 state_ops.scatter_update(ref, indices, updates) 748 self.assertAllEqual(ref.read_value(), [True, True, True]) 749 750 @test_util.run_in_graph_and_eager_modes 751 def testConstraintArg(self): 752 constraint = lambda x: x 753 v = resource_variable_ops.ResourceVariable( 754 initial_value=lambda: 1, constraint=constraint, name="var0") 755 self.assertEqual(v.constraint, constraint) 756 757 constraint = 0 758 with self.assertRaises(ValueError): 759 v = resource_variable_ops.ResourceVariable( 760 initial_value=lambda: 1, constraint=constraint, name="var1") 761 762 # TODO(alive): how should this work in Eager mode? 763 @test_util.run_deprecated_v1 764 def testInitFn(self): 765 with self.cached_session(): 766 v = resource_variable_ops.ResourceVariable( 767 initial_value=lambda: 1, dtype=dtypes.float32) 768 self.assertEqual(v.handle.op.colocation_groups(), 769 v.initializer.inputs[1].op.colocation_groups()) 770 771 def testCountUpTo(self): 772 with context.eager_mode(): 773 v = resource_variable_ops.ResourceVariable(0, name="upto") 774 self.assertAllEqual(v.count_up_to(1), 0) 775 with self.assertRaises(errors.OutOfRangeError): 776 v.count_up_to(1) 777 778 def testCountUpToFunction(self): 779 with context.eager_mode(): 780 v = resource_variable_ops.ResourceVariable(0, name="upto") 781 self.assertAllEqual(state_ops.count_up_to(v, 1), 0) 782 with self.assertRaises(errors.OutOfRangeError): 783 state_ops.count_up_to(v, 1) 784 785 @test_util.run_in_graph_and_eager_modes 786 def testInitFnDtype(self): 787 v = resource_variable_ops.ResourceVariable( 788 initial_value=lambda: 1, dtype=dtypes.float32, name="var0") 789 self.assertEqual(dtypes.float32, v.value().dtype) 790 791 @test_util.run_in_graph_and_eager_modes 792 def testInitFnNoDtype(self): 793 v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1, 794 name="var2") 795 self.assertEqual(dtypes.int32, v.value().dtype) 796 797 @test_util.run_in_graph_and_eager_modes 798 def testInitializeAllVariables(self): 799 v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.float32, 800 name="var0") 801 self.evaluate(variables.global_variables_initializer()) 802 self.assertEqual(1.0, self.evaluate(v.value())) 803 804 @test_util.run_in_graph_and_eager_modes 805 def testOperatorOverload(self): 806 v = resource_variable_ops.ResourceVariable(1.0, name="var0") 807 self.evaluate(variables.global_variables_initializer()) 808 self.assertEqual(2.0, self.evaluate(v + v)) 809 810 @test_util.run_in_graph_and_eager_modes 811 def testAssignMethod(self): 812 v = resource_variable_ops.ResourceVariable(1.0, name="var0") 813 self.evaluate(variables.global_variables_initializer()) 814 self.evaluate(v.assign(2.0)) 815 self.assertEqual(2.0, self.evaluate(v.value())) 816 817 # Tests for the 'read_value' argument: 818 assign_with_read = v.assign(3.0, read_value=True) 819 self.assertEqual(3.0, self.evaluate(assign_with_read)) 820 assign_without_read = v.assign(4.0, read_value=False) 821 if context.executing_eagerly(): 822 self.assertIsNone(assign_without_read) 823 else: 824 self.assertIsInstance(assign_without_read, ops.Operation) 825 self.evaluate(assign_without_read) 826 self.assertEqual(4.0, self.evaluate(v.value())) 827 828 def testAssignRuntimeShapeCheck(self): 829 with forward_compat.forward_compatibility_horizon(2022, 3, 30): 830 v = resource_variable_ops.ResourceVariable([1.0, 1.0], name="var0") 831 832 @def_function.function 833 def f(shape): 834 t = array_ops.zeros(shape) 835 v.assign(t) 836 837 with self.assertRaises((errors.InvalidArgumentError, ValueError)): 838 f(constant_op.constant([3])) 839 840 @test_util.run_in_graph_and_eager_modes 841 def testLoad(self): 842 v = resource_variable_ops.ResourceVariable(1.0, name="var0") 843 self.evaluate(variables.global_variables_initializer()) 844 v.load(2.0) 845 self.assertEqual(2.0, self.evaluate(v.value())) 846 847 def testShapePassedToGradient(self): 848 with ops.Graph().as_default(): 849 @custom_gradient.custom_gradient 850 def differentiable_scatter_update(handle, indices, values): 851 with ops.control_dependencies([ 852 resource_variable_ops.resource_scatter_update( 853 handle, indices, values)]): 854 new_handle = array_ops.identity(handle) 855 856 def grad(dresult): 857 self.assertIsNotNone( 858 tensor_util.constant_value(dresult.dense_shape)) 859 return [dresult, None, None] 860 861 return new_handle, grad 862 863 var = variable_scope.get_variable( 864 "foo", shape=[20], initializer=init_ops.zeros_initializer, 865 dtype=dtypes.float64, use_resource=True) 866 867 indices = math_ops.range(10) 868 updates = math_ops.range(9, -1, -1, dtype=dtypes.float64) 869 new_handle = differentiable_scatter_update(var.handle, indices, updates) 870 gathered = resource_variable_ops.resource_gather( 871 new_handle, indices, dtype=var.dtype) 872 gradients_impl.gradients([gathered], [updates]) 873 874 def testToFromProtoCachedValue(self): 875 with ops.Graph().as_default(): 876 v_def = resource_variable_ops.ResourceVariable( 877 initial_value=constant_op.constant(3.0)).to_proto() 878 v_prime = resource_variable_ops.ResourceVariable(variable_def=v_def) 879 self.assertIsNone(getattr(v_prime, "_cached_value", None)) 880 881 other_v_def = resource_variable_ops.ResourceVariable( 882 caching_device="cpu:0", 883 initial_value=constant_op.constant(3.0)).to_proto() 884 other_v_prime = resource_variable_ops.ResourceVariable( 885 variable_def=other_v_def) 886 self.assertIsNotNone(other_v_prime._cached_value) 887 888 def testVariableDefInitializedInstances(self): 889 with ops.Graph().as_default(), self.cached_session(): 890 v_def = resource_variable_ops.ResourceVariable( 891 initial_value=constant_op.constant(3.0)).to_proto() 892 893 with ops.Graph().as_default(), self.cached_session(): 894 # v describes a VariableDef-based variable without an initial value. 895 v = resource_variable_ops.ResourceVariable(variable_def=v_def) 896 self.assertEqual(3.0, self.evaluate(v.initialized_value())) 897 898 # initialized_value should not rerun the initializer_op if the variable 899 # has already been initialized elsewhere. 900 self.evaluate(v.assign(1.0)) 901 self.assertEqual(1.0, v.initialized_value().eval()) 902 903 v_def.ClearField("initial_value_name") 904 with ops.Graph().as_default(), self.cached_session(): 905 # Restoring a legacy VariableDef proto that does not have 906 # initial_value_name set should still work. 907 v = resource_variable_ops.ResourceVariable(variable_def=v_def) 908 # We should also be able to re-export the variable to a new meta graph. 909 self.assertProtoEquals(v_def, v.to_proto()) 910 # But attempts to use initialized_value will result in errors. 911 with self.assertRaises(ValueError): 912 self.evaluate(v.initialized_value()) 913 914 def testTrainableInProto(self): 915 with ops.Graph().as_default(): 916 non_trainable_variable = resource_variable_ops.ResourceVariable( 917 trainable=False, 918 initial_value=constant_op.constant(10.0)) 919 self.assertEqual( 920 False, 921 resource_variable_ops.ResourceVariable( 922 variable_def=non_trainable_variable.to_proto()) 923 .trainable) 924 trainable_variable = resource_variable_ops.ResourceVariable( 925 trainable=True, 926 initial_value=constant_op.constant(10.0)) 927 self.assertEqual( 928 True, 929 resource_variable_ops.ResourceVariable( 930 variable_def=trainable_variable.to_proto()) 931 .trainable) 932 933 @test_util.run_in_graph_and_eager_modes 934 def testSparseRead(self): 935 init_value = np.reshape(np.arange(np.power(4, 3)), (4, 4, 4)) 936 v = resource_variable_ops.ResourceVariable( 937 constant_op.constant(init_value, dtype=dtypes.int32), name="var3") 938 self.evaluate(variables.global_variables_initializer()) 939 940 value = self.evaluate(v.sparse_read([0, 3, 1, 2])) 941 self.assertAllEqual(init_value[[0, 3, 1, 2], ...], value) 942 943 @test_util.run_in_graph_and_eager_modes 944 def testGatherNd(self): 945 init_value = np.reshape(np.arange(np.power(4, 3)), (4, 4, 4)) 946 v = resource_variable_ops.ResourceVariable( 947 constant_op.constant(init_value, dtype=dtypes.int32), name="var3") 948 self.evaluate(variables.global_variables_initializer()) 949 950 value_op = v.gather_nd([[0, 0], [1, 2], [3, 3]]) 951 self.assertAllEqual([3, 4], value_op.shape) 952 value = self.evaluate(value_op) 953 self.assertAllEqual([[0, 1, 2, 3], [24, 25, 26, 27], [60, 61, 62, 63]], 954 value) 955 956 value_op = v.gather_nd([[0, 0, 0], [1, 2, 3], [3, 3, 3]]) 957 self.assertAllEqual([3], value_op.shape) 958 value = self.evaluate(value_op) 959 self.assertAllEqual([0, 27, 63], value) 960 961 @test_util.run_deprecated_v1 962 def testToFromProto(self): 963 with self.cached_session(): 964 v = resource_variable_ops.ResourceVariable(1.0) 965 self.evaluate(variables.global_variables_initializer()) 966 967 w = resource_variable_ops.ResourceVariable.from_proto(v.to_proto()) 968 self.assertEqual(2, math_ops.add(w, 1).eval()) 969 970 self.assertEqual(v._handle, w._handle) 971 self.assertEqual(v._graph_element, w._graph_element) 972 973 @test_util.run_in_graph_and_eager_modes 974 def testAssignAddMethod(self): 975 v = resource_variable_ops.ResourceVariable(1.0, name="var0") 976 self.evaluate(variables.global_variables_initializer()) 977 self.evaluate(v.assign_add(1.0)) 978 self.assertEqual(2.0, self.evaluate(v.value())) 979 980 # Tests for the 'read_value' argument: 981 assign_with_read = v.assign_add(1.0, read_value=True) 982 self.assertEqual(3.0, self.evaluate(assign_with_read)) 983 assign_without_read = v.assign_add(1.0, read_value=False) 984 if context.executing_eagerly(): 985 self.assertIsNone(assign_without_read) 986 else: 987 self.assertIsInstance(assign_without_read, ops.Operation) 988 self.evaluate(assign_without_read) 989 self.assertEqual(4.0, self.evaluate(v.value())) 990 991 @test_util.run_in_graph_and_eager_modes 992 def testAssignSubMethod(self): 993 v = resource_variable_ops.ResourceVariable(3.0, name="var0") 994 self.evaluate(variables.global_variables_initializer()) 995 self.evaluate(v.assign_sub(1.0)) 996 self.assertEqual(2.0, self.evaluate(v.value())) 997 998 # Tests for the 'read_value' argument: 999 assign_with_read = v.assign_sub(1.0, read_value=True) 1000 self.assertEqual(1.0, self.evaluate(assign_with_read)) 1001 assign_without_read = v.assign_sub(1.0, read_value=False) 1002 if context.executing_eagerly(): 1003 self.assertIsNone(assign_without_read) 1004 else: 1005 self.assertIsInstance(assign_without_read, ops.Operation) 1006 self.evaluate(assign_without_read) 1007 self.assertEqual(0.0, self.evaluate(v.value())) 1008 1009 @test_util.run_in_graph_and_eager_modes 1010 @test_util.run_v1_only("b/120545219") 1011 def testDestroyResource(self): 1012 v = resource_variable_ops.ResourceVariable(3.0, name="var0") 1013 self.evaluate(variables.global_variables_initializer()) 1014 self.assertEqual(3.0, self.evaluate(v.value())) 1015 self.evaluate(resource_variable_ops.destroy_resource_op(v.handle)) 1016 if context.executing_eagerly(): 1017 # eager mode creates ref-counting variable handles unaffected by 1018 # DestroyResourceOp. 1019 self.assertEqual(3.0, self.evaluate(v.value())) 1020 else: 1021 with self.assertRaises(errors.FailedPreconditionError): 1022 self.evaluate(v.value()) 1023 # Handle to a resource not actually created. 1024 handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[]) 1025 # Should raise no exception 1026 self.evaluate(resource_variable_ops.destroy_resource_op( 1027 handle, ignore_lookup_error=True)) 1028 1029 @test_util.run_deprecated_v1 1030 def testAssignDifferentShapes(self): 1031 with self.cached_session() as sess, variable_scope.variable_scope( 1032 "foo", use_resource=True): 1033 var = variable_scope.get_variable("x", shape=[1, 1], dtype=dtypes.float32) 1034 placeholder = array_ops.placeholder(dtypes.float32) 1035 assign = var.assign(placeholder) 1036 sess.run( 1037 [assign], 1038 feed_dict={placeholder: np.zeros(shape=[2, 2], dtype=np.float32)}) 1039 1040 def testAssignDifferentShapesEagerNotAllowed(self): 1041 with context.eager_mode(): 1042 with variable_scope.variable_scope("foo"): 1043 var = variable_scope.get_variable("x", shape=[1, 1], 1044 dtype=dtypes.float32) 1045 with self.assertRaisesRegex(ValueError, 1046 "shape.*and.*are incompatible"): 1047 assign = var.assign(np.zeros(shape=[2, 2])) 1048 self.evaluate(assign) 1049 1050 @test_util.disable_xla("XLA doesn't allow changing shape at assignment, as " 1051 "dictated by tf2xla/xla_resource.cc:SetTypeAndShape") 1052 @test_util.run_in_graph_and_eager_modes 1053 def testAssignDifferentShapesAllowed(self): 1054 var = resource_variable_ops.ResourceVariable( 1055 initial_value=np.zeros(shape=[1, 1]), 1056 shape=tensor_shape.TensorShape(None)) 1057 self.evaluate(variables.global_variables_initializer()) 1058 self.assertAllEqual(np.zeros(shape=[1, 1]), var.read_value()) 1059 self.evaluate(var.assign(np.zeros(shape=[2, 2]))) 1060 self.assertAllEqual(np.zeros(shape=[2, 2]), var.read_value()) 1061 1062 @test_util.run_in_graph_and_eager_modes 1063 def testAssignReturnsVariable(self): 1064 var = resource_variable_ops.ResourceVariable(1.) 1065 self.evaluate(variables.global_variables_initializer()) 1066 assigned = var.assign(2.) 1067 self.assertIsInstance(assigned, resource_variable_ops.BaseResourceVariable) 1068 assigned = assigned.assign(3.) 1069 self.assertEqual(self.evaluate(assigned), 3.) 1070 self.assertEqual(self.evaluate(var), 3.) 1071 1072 self.assertEqual(self.evaluate(var.assign_add(1.).assign_add(1.)), 5) 1073 self.assertEqual(self.evaluate(var.assign_sub(1.).assign_sub(1.)), 3) 1074 1075 var = resource_variable_ops.ResourceVariable([1., 2.]) 1076 self.evaluate(variables.global_variables_initializer()) 1077 slices = indexed_slices.IndexedSlices(indices=[1], values=[2]) 1078 def assert_eq(tensor, vals): 1079 self.assertAllEqual(self.evaluate(tensor), vals) 1080 assert_eq(var.scatter_add(slices).scatter_add(slices), [1., 6.]) 1081 assert_eq(var.scatter_sub(slices).scatter_sub(slices), [1., 2.]) 1082 slices2 = indexed_slices.IndexedSlices(indices=[0], values=[3]) 1083 assert_eq(var.scatter_max(slices2).scatter_add(slices), [3., 4.]) 1084 assert_eq(var.scatter_add(slices).scatter_min(slices), [3., 2.]) 1085 assert_eq(var.scatter_mul(slices).scatter_mul(slices), [3., 8.]) 1086 assert_eq(var.scatter_div(slices).scatter_div(slices), [3., 2.]) 1087 assert_eq( 1088 var.scatter_nd_update([[1]], [4.]).scatter_nd_add([[0]], [2.]) 1089 .scatter_nd_sub([[1]], [3]), 1090 [5., 1.]) 1091 assert_eq(var, [5., 1.]) 1092 1093 batch_var = resource_variable_ops.ResourceVariable(array_ops.ones((2, 2))) 1094 self.evaluate(variables.global_variables_initializer()) 1095 batch_slices1 = indexed_slices.IndexedSlices( 1096 indices=[[1], [0]], values=[[2], [2]]) 1097 batch_slices2 = indexed_slices.IndexedSlices( 1098 indices=[[1], [1]], values=[[3], [3]]) 1099 assert_eq( 1100 batch_var.batch_scatter_update(batch_slices1) 1101 .batch_scatter_update(batch_slices2), 1102 [[1, 3], [2, 3]]) 1103 1104 @test_util.run_in_graph_and_eager_modes 1105 def testInitValueWrongShape(self): 1106 with self.assertRaisesWithPredicateMatch( 1107 ValueError, r"not compatible with"): 1108 var = resource_variable_ops.ResourceVariable( 1109 initial_value=np.zeros(shape=[3]), 1110 shape=[4]) 1111 self.evaluate(variables.global_variables_initializer()) 1112 self.evaluate(var.read_value()) 1113 1114 @test_util.run_deprecated_v1 1115 def testDtypeAfterFromProto(self): 1116 v = resource_variable_ops.ResourceVariable(2.0) 1117 w = resource_variable_ops.ResourceVariable.from_proto(v.to_proto()) 1118 self.assertIsInstance(w.dtype, dtypes.DType) 1119 self.assertEqual(v.dtype, w.dtype) 1120 1121 # TODO(alive): get caching to work in eager mode. 1122 @test_util.run_deprecated_v1 1123 def testCachingDevice(self): 1124 with ops.device("/job:server/task:1"): 1125 v = resource_variable_ops.ResourceVariable( 1126 2.0, caching_device="/job:localhost") 1127 self.assertEqual("/job:localhost", v.value().device) 1128 with self.assertRaises(ValueError): 1129 _ = v.value().op.get_attr("_class") 1130 1131 with ops.colocate_with(v.op): 1132 w = resource_variable_ops.ResourceVariable( 1133 2.0, caching_device="/job:localhost") 1134 self.assertEqual("/job:localhost", w.value().device) 1135 with self.assertRaises(ValueError): 1136 _ = w.value().op.get_attr("_class") 1137 1138 @test_util.run_deprecated_v1 1139 def testSharedName(self): 1140 with self.cached_session(): 1141 v = resource_variable_ops.ResourceVariable(300.0, name="var4") 1142 self.evaluate(variables.global_variables_initializer()) 1143 1144 w = _eager_safe_var_handle_op( 1145 dtype=v.dtype.base_dtype, 1146 shape=v.get_shape(), 1147 shared_name="var4", 1148 # Needed in Eager since we get a unique container name by default. 1149 container=ops.get_default_graph()._container) 1150 w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype) 1151 self.assertEqual(300.0, self.evaluate(w_read)) 1152 1153 x = _eager_safe_var_handle_op( 1154 dtype=v.dtype.base_dtype, 1155 shape=v.get_shape(), 1156 shared_name="var5", 1157 container=ops.get_default_graph()._container) 1158 with self.assertRaisesOpError( 1159 "(Resource .*/var5/.* does not exist|uninitialized)"): 1160 resource_variable_ops.read_variable_op(x, v.dtype.base_dtype).eval() 1161 1162 @test_util.run_deprecated_v1 1163 def testSharedNameWithNamescope(self): 1164 with self.cached_session(): 1165 with ops.name_scope("foo"): 1166 v = resource_variable_ops.ResourceVariable(300.0, name="var6") 1167 self.assertEqual("foo/var6", v._shared_name) # pylint: disable=protected-access 1168 self.assertEqual("foo/var6:0", v.name) 1169 self.evaluate(variables.global_variables_initializer()) 1170 1171 w = _eager_safe_var_handle_op( 1172 dtype=v.dtype.base_dtype, 1173 shape=v.get_shape(), 1174 shared_name="foo/var6", 1175 # Needed in Eager since we get a unique container name by default. 1176 container=ops.get_default_graph()._container) 1177 w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype) 1178 self.assertEqual(300.0, self.evaluate(w_read)) 1179 1180 @test_util.run_in_graph_and_eager_modes 1181 def testShape(self): 1182 v = resource_variable_ops.ResourceVariable( 1183 name="var4", initial_value=array_ops.ones(shape=[10, 20, 35])) 1184 self.assertEqual("(10, 20, 35)", str(v.shape)) 1185 self.assertEqual("(10, 20, 35)", str(v.get_shape())) 1186 self.assertEqual("(10, 20, 35)", str(v.value().shape)) 1187 self.assertEqual("(3, 20, 35)", str(v.sparse_read([0, 1, 2]).shape)) 1188 if not context.executing_eagerly(): 1189 self.assertEqual( 1190 "<unknown>", 1191 str(v.sparse_read(array_ops.placeholder(dtypes.int32)).shape)) 1192 1193 @test_util.run_deprecated_v1 1194 def testSetInitialValue(self): 1195 with self.cached_session(): 1196 # Initialize variable with a value different from the initial value passed 1197 # in the constructor. 1198 v = resource_variable_ops.ResourceVariable(2.0) 1199 v.initializer.run(feed_dict={v.initial_value: 3.0}) 1200 self.assertEqual(3.0, v.value().eval()) 1201 1202 @test_util.run_v1_only("b/120545219") 1203 def testControlFlowInitialization(self): 1204 """Expects an error if an initializer is in a control-flow scope.""" 1205 1206 def cond(i, _): 1207 return i < 10 1208 1209 def body(i, _): 1210 zero = array_ops.zeros([], dtype=dtypes.int32) 1211 v = resource_variable_ops.ResourceVariable(initial_value=zero) 1212 return (i + 1, v.read_value()) 1213 1214 with self.assertRaisesRegex(ValueError, "initial_value"): 1215 control_flow_ops.while_loop(cond, body, [0, 0]) 1216 1217 def testVariableEager(self): 1218 with context.eager_mode(): 1219 init = array_ops.ones(shape=[10, 20, 35], dtype=dtypes.int32) 1220 constraint = lambda x: x 1221 with ops.name_scope("foo", skip_on_eager=False): 1222 v = resource_variable_ops.ResourceVariable( 1223 name="var7", 1224 initial_value=init, 1225 caching_device="cpu:0", 1226 constraint=constraint) 1227 # Test properties 1228 self.assertEqual(dtypes.int32, v.dtype) 1229 self.assertEqual("foo/var7:0", v.name) 1230 self.assertAllEqual([10, 20, 35], v.shape.as_list()) 1231 self.assertIsInstance(v.handle, ops.EagerTensor) 1232 self.assertEqual(constraint, v.constraint) 1233 self.assertAllEqual(init.numpy(), v.read_value().numpy()) 1234 self.assertAllEqual(init.numpy(), v.value().numpy()) 1235 1236 # Callable init. 1237 callable_init = lambda: init * 2 1238 v2 = resource_variable_ops.ResourceVariable( 1239 initial_value=callable_init, name="var7") 1240 self.assertEqual("var7:0", v2.name) 1241 self.assertAllEqual(2 * init.numpy(), v2.read_value().numpy()) 1242 1243 # Test assign_add. 1244 new_v2_val = v2.assign_add(v.read_value()) 1245 self.assertAllEqual(v.read_value().numpy() * 3, new_v2_val.numpy()) 1246 1247 # Test assign_sub. 1248 new_v2_val = v2.assign_sub(v.read_value()) 1249 self.assertAllEqual(v.read_value().numpy() * 2, new_v2_val.numpy()) 1250 1251 # Test assign. 1252 v2.assign(v.read_value()) 1253 self.assertAllEqual(v.read_value().numpy(), v2.read_value().numpy()) 1254 1255 # Test load 1256 v2.load(2 * v.read_value()) 1257 self.assertAllEqual(2 * v.read_value().numpy(), v2.read_value().numpy()) 1258 1259 # Test convert_to_tensor 1260 t = ops.convert_to_tensor(v) 1261 self.assertAllEqual(t.numpy(), v.read_value().numpy()) 1262 1263 # Test operations 1264 self.assertAllEqual((v * 2).numpy(), (v + v).numpy()) 1265 1266 def testNumpyDotArray(self): 1267 with context.eager_mode(): 1268 # Scalars use a separate code path. 1269 v1 = resource_variable_ops.ResourceVariable(initial_value=lambda: 1, 1270 name="v1") 1271 self.assertEqual(1, np.array(v1)) 1272 1273 v2 = resource_variable_ops.ResourceVariable(initial_value=lambda: [1, 2], 1274 name="v2") 1275 self.assertAllEqual(v2.read_value().numpy(), np.array(v2)) 1276 self.assertAllEqual([1, 2], np.array(v2)) 1277 1278 def testContainerEager(self): 1279 with context.eager_mode(): 1280 v1 = resource_variable_ops.ResourceVariable(initial_value=lambda: 1, 1281 name="same") 1282 with ops.container("different"): 1283 v2 = resource_variable_ops.ResourceVariable(initial_value=lambda: 0, 1284 name="same") 1285 v2.assign(2) 1286 self.assertEqual(1, v1.read_value().numpy()) 1287 self.assertEqual(2, v2.read_value().numpy()) 1288 1289 def testDestruction(self): 1290 with context.eager_mode(): 1291 var = resource_variable_ops.ResourceVariable(initial_value=1.0, 1292 name="var8") 1293 var_handle = test_ops.make_weak_resource_handle(var._handle) 1294 del var 1295 with self.assertRaisesRegex(errors.NotFoundError, 1296 r"Resource .* does not exist."): 1297 resource_variable_ops.destroy_resource_op(var_handle, 1298 ignore_lookup_error=False) 1299 1300 def testScatterUpdate(self): 1301 with context.eager_mode(): 1302 v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="update") 1303 state_ops.scatter_update(v, [1], [3.0]) 1304 self.assertAllEqual([1.0, 3.0], v.numpy()) 1305 1306 def testScatterAddStateOps(self): 1307 with context.eager_mode(): 1308 v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="add") 1309 state_ops.scatter_add(v, [1], [3]) 1310 self.assertAllEqual([1.0, 5.0], v.numpy()) 1311 1312 def testScatterSubStateOps(self): 1313 with context.eager_mode(): 1314 v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="sub") 1315 state_ops.scatter_sub(v, [1], [3]) 1316 self.assertAllEqual([1.0, -1.0], v.numpy()) 1317 1318 def testScatterUpdateVariant(self): 1319 with context.eager_mode(): 1320 v = resource_variable_ops.ResourceVariable([ 1321 list_ops.empty_tensor_list( 1322 element_dtype=dtypes.float32, element_shape=[]) 1323 ]) 1324 v.scatter_update( 1325 indexed_slices.IndexedSlices( 1326 list_ops.tensor_list_from_tensor([1., 2.], element_shape=[]), 0)) 1327 self.assertAllEqual( 1328 list_ops.tensor_list_get_item(v[0], 0, element_dtype=dtypes.float32), 1329 1.) 1330 1331 def testGroupDoesntForceRead(self): 1332 with ops.Graph().as_default(): 1333 v = resource_variable_ops.ResourceVariable(1.0) 1334 assign = v.assign_add(1.0) 1335 g = control_flow_ops.group([assign]) 1336 self.assertEqual(g.control_inputs[0].type, "AssignAddVariableOp") 1337 1338 def testScatterNdAddStateOps(self): 1339 with context.eager_mode(): 1340 v = resource_variable_ops.ResourceVariable( 1341 [1, 2, 3, 4, 5, 6, 7, 8], dtype=dtypes.float32, name="add") 1342 indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32) 1343 updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32) 1344 expected = np.array([1, 13, 3, 14, 14, 6, 7, 20]) 1345 state_ops.scatter_nd_add(v, indices, updates) 1346 self.assertAllClose(expected, v.numpy()) 1347 1348 @test_util.run_in_graph_and_eager_modes 1349 def testUnreadVariableInsideFunction(self): 1350 v = resource_variable_ops.ResourceVariable(1.0) 1351 1352 @def_function.function 1353 def assign(): 1354 v.assign(1.0) 1355 1356 graph = assign.get_concrete_function().graph 1357 self.assertTrue(all(x.type != "ReadVariableOp" 1358 for x in graph.get_operations())) 1359 1360 def testScatterNdSubStateOps(self): 1361 with context.eager_mode(): 1362 v = resource_variable_ops.ResourceVariable( 1363 [1, 2, 3, 4, 5, 6, 7, 8], dtype=dtypes.float32, name="sub") 1364 indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32) 1365 updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32) 1366 expected = np.array([1, -9, 3, -6, -4, 6, 7, -4]) 1367 state_ops.scatter_nd_sub(v, indices, updates) 1368 self.assertAllClose(expected, v.numpy()) 1369 1370 def testScatterUpdateCast(self): 1371 with context.eager_mode(): 1372 v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="update") 1373 state_ops.scatter_update(v, [1], [3]) 1374 self.assertAllEqual([1.0, 3.0], v.numpy()) 1375 1376 @test_util.run_in_graph_and_eager_modes 1377 def testScatterUpdateInvalidArgs(self): 1378 v = resource_variable_ops.ResourceVariable([0, 1, 2, 3], name="update") 1379 # The exact error and message differ between graph construction (where the 1380 # error is realized during shape inference at graph construction time), 1381 # eager execution (where the error is realized during kernel execution), 1382 # and XLA auto-clustering execution (where the error is realized in the xla 1383 # op kernel) which is triggered when running in eager op as function mode. 1384 with self.assertRaisesRegex(Exception, r"shape.*2.*3|RET_CHECK failure"): 1385 state_ops.scatter_update(v, [0, 1], [0, 1, 2]) 1386 1387 @test_util.disable_xla("b/208334252") # XLA doesn't have a deterministic impl 1388 def testScatterAddDeterministic(self): 1389 with context.eager_mode(), test_util.deterministic_ops(): 1390 # Normally a nondeterministic codepath occurs when the variable has at 1391 # least 1024 elements. Test that op determinism ensures the op is 1392 # deterministc. 1393 v = resource_variable_ops.ResourceVariable(array_ops.zeros([1024])) 1394 delta = ops.IndexedSlices( 1395 values=np.random.normal(size=(1_000_000,)), 1396 indices=array_ops.zeros((1_000_000,), dtype=np.int32), 1397 dense_shape=(1024,)) 1398 v.scatter_add(delta) 1399 for _ in range(5): 1400 v2 = resource_variable_ops.ResourceVariable(array_ops.zeros([1024])) 1401 v2.scatter_add(delta) 1402 self.assertAllEqual(v, v2) 1403 1404 @test_util.run_in_graph_and_eager_modes 1405 def testAssignIncompatibleShape(self): 1406 v = resource_variable_ops.ResourceVariable([0, 1, 2, 3]) 1407 self.evaluate(v.initializer) 1408 pattern = re.compile("shapes must be equal", re.IGNORECASE) 1409 with self.assertRaisesRegex(Exception, pattern): 1410 self.evaluate(v.assign_add(1)) 1411 1412 @test_util.run_in_graph_and_eager_modes 1413 @test_util.run_v1_only("b/120545219") 1414 def testCopyToGraphUninitialized(self): 1415 v = resource_variable_ops.ResourceVariable([0, 1, 2, 3]) 1416 copy_to_graph = ops.Graph() 1417 with copy_to_graph.as_default(): # Intentionally testing v1 behavior 1418 copied = resource_variable_ops.copy_to_graph_uninitialized(v) 1419 self.assertEqual(v.name, copied.name) 1420 self.assertIsNone(copied.initializer) 1421 1422 def create_variant_shape_and_type_data(self): 1423 variant_shape_and_type_data = ( 1424 cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()) 1425 variant_shape_and_type_data.is_set = True 1426 stored_shape = tensor_shape.TensorShape([None, 4]).as_proto() 1427 stored_dtype = dtypes.float32.as_datatype_enum 1428 # NOTE(ebrevdo): shape_and_type lacks append() in some versions of protobuf. 1429 variant_shape_and_type_data.shape_and_type.extend([ 1430 cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType( 1431 shape=stored_shape, 1432 dtype=stored_dtype, 1433 type=full_type_pb2.FullTypeDef()) 1434 ]) 1435 return variant_shape_and_type_data 1436 1437 @def_function.function 1438 def create_constant_variant(self, value): 1439 value = constant_op.constant( 1440 tensor_pb2.TensorProto( 1441 dtype=dtypes.variant.as_datatype_enum, 1442 tensor_shape=tensor_shape.TensorShape([]).as_proto(), 1443 variant_val=[ 1444 tensor_pb2.VariantTensorDataProto( 1445 # Match registration in variant_op_registry.cc 1446 type_name=b"int", 1447 metadata=np.array(value, dtype=np.int32).tobytes()) 1448 ])) 1449 return value 1450 1451 # TODO(ebrevdo): Add run_in_graph_and_eager_modes once we can create 1452 # EagerTensor constants with TensorProto inputs. 1453 @test_util.disable_tfrt("Does not support tf.Const in lowering.") 1454 @test_util.run_in_graph_and_eager_modes() 1455 def testVariantInitializer(self): 1456 variant_shape_and_type_data = self.create_variant_shape_and_type_data() 1457 value = self.create_constant_variant(3) 1458 initializer = array_ops.fill([3], value) 1459 resource_variable_ops._set_handle_shapes_and_types( # pylint: disable=protected-access 1460 initializer, variant_shape_and_type_data, 1461 graph_mode=not context.executing_eagerly()) 1462 v = resource_variable_ops.ResourceVariable(initializer) 1463 read = array_ops.identity(v) 1464 read_variant_shape_and_type = ( 1465 resource_variable_ops.get_eager_safe_handle_data(read)) 1466 self.assertEqual( 1467 read_variant_shape_and_type, variant_shape_and_type_data) 1468 gather = v.sparse_read([0]) 1469 gather_variant_shape_and_type = ( 1470 resource_variable_ops.get_eager_safe_handle_data(gather)) 1471 self.assertEqual( 1472 gather_variant_shape_and_type, variant_shape_and_type_data) 1473 # Make sure initializer runs. 1474 if not context.executing_eagerly(): 1475 self.evaluate(v.initializer) 1476 self.evaluate(read.op) 1477 self.evaluate(gather.op) 1478 1479 @parameterized.parameters([ 1480 # batch_dims=0 (equivalent to tf.gather) 1481 dict( # 2D indices 1482 batch_dims=0, 1483 params=[6, 7, 8, 9], 1484 indices=[[2, 1], [0, 3]], 1485 expected=[[8, 7], [6, 9]]), 1486 dict( # 3D indices 1487 batch_dims=0, 1488 params=[6, 7, 8, 9], 1489 indices=[[[3, 1], [2, 0]], [[0, 3], [2, 2]]], 1490 expected=[[[9, 7], [8, 6]], [[6, 9], [8, 8]]]), 1491 dict( # 4D indices 1492 batch_dims=0, 1493 params=[8, 9], 1494 indices=[[[[0, 1], [1, 0]], [[0, 0], [1, 1]]], 1495 [[[1, 1], [0, 0]], [[0, 1], [1, 0]]]], 1496 expected=[[[[8, 9], [9, 8]], [[8, 8], [9, 9]]], 1497 [[[9, 9], [8, 8]], [[8, 9], [9, 8]]]]), 1498 1499 # batch_dims=indices.shape.ndims - 1 (equivalent to 1500 # tf.compat.v1.batch_gather) 1501 dict( # 2D indices (1 batch dim) 1502 batch_dims=1, 1503 params=[[10, 11, 12, 13], [20, 21, 22, 23]], 1504 indices=[[2, 1], [0, 3]], 1505 expected=[[12, 11], [20, 23]]), 1506 dict( # 3D indices (2 batch dims) 1507 batch_dims=2, 1508 params=[[[100, 101], [110, 111]], [[200, 201], [210, 211]]], 1509 indices=[[[0, 1], [1, 0]], [[0, 0], [1, 1]]], 1510 expected=[[[100, 101], [111, 110]], [[200, 200], [211, 211]]]), 1511 dict( # 2D indices (1 batch dim) 1512 batch_dims=1, 1513 params=[[10, 11, 12, 13], [20, 21, 22, 23]], 1514 indices=[[2, 1], [0, 3]], 1515 expected=[[12, 11], [20, 23]]), 1516 dict( # 3D indices (2 batch dims) 1517 batch_dims=2, 1518 params=[[[100, 101], [110, 111]], [[200, 201], [210, 211]]], 1519 indices=[[[0, 1], [1, 0]], [[0, 0], [1, 1]]], 1520 expected=[[[100, 101], [111, 110]], [[200, 200], [211, 211]]]), 1521 1522 # 0 < batch_dims < indices.shape.ndims - 1 1523 dict( # 3D indices (1 batch dim) 1524 batch_dims=1, 1525 params=[[10, 11, 12, 13], [20, 21, 22, 23]], 1526 indices=[[[3, 1], [2, 0]], [[0, 3], [2, 2]]], 1527 expected=[[[13, 11], [12, 10]], [[20, 23], [22, 22]]]), 1528 dict( # 4D indices (1 batch dim) 1529 batch_dims=1, 1530 params=[[6, 7], [8, 9]], 1531 indices=[[[[0, 1], [1, 0]], [[0, 0], [1, 1]]], 1532 [[[1, 1], [0, 0]], [[0, 1], [1, 0]]]], 1533 expected=[[[[6, 7], [7, 6]], [[6, 6], [7, 7]]], 1534 [[[9, 9], [8, 8]], [[8, 9], [9, 8]]]]), 1535 dict( # 4D indices (2 batch dims) 1536 batch_dims=2, 1537 params=[[[2, 3], [4, 5]], [[6, 7], [8, 9]]], 1538 indices=[[[[0, 1], [1, 0]], [[0, 0], [1, 1]]], 1539 [[[1, 1], [0, 0]], [[0, 1], [1, 0]]]], 1540 expected=[[[[2, 3], [3, 2]], [[4, 4], [5, 5]]], 1541 [[[7, 7], [6, 6]], [[8, 9], [9, 8]]]]), 1542 ]) 1543 @test_util.run_in_graph_and_eager_modes 1544 def testGatherWithBatchDims(self, params, indices, batch_dims, expected): 1545 var = resource_variable_ops.ResourceVariable(params, name="var0") 1546 with ops.control_dependencies([var.initializer]): 1547 result = resource_variable_ops.resource_gather( 1548 var.handle, indices, dtype=var.dtype, batch_dims=batch_dims) 1549 self.assertAllEqual(expected, result) 1550 1551 @parameterized.parameters([ 1552 dict( 1553 params_shape=[2, 3, 4, 5, 6, 7], 1554 indices_shape=[2, 3, 8, 9, 10], 1555 batch_dims=0, 1556 output_shape=[2, 3, 8, 9, 10, 3, 4, 5, 6, 7] 1557 # = indices.shape + params.shape[1:] 1558 ), 1559 dict( 1560 params_shape=[2, 3, 4, 5, 6, 7], 1561 indices_shape=[2, 3, 8, 9, 10], 1562 batch_dims=1, 1563 output_shape=[2, 3, 8, 9, 10, 4, 5, 6, 7] 1564 # = params.shape[:1] + indices.shape[1:] + params.shape[2:] 1565 ), 1566 dict( 1567 params_shape=[2, 3, 4, 5, 6, 7], 1568 indices_shape=[2, 3, 8, 9, 10], 1569 batch_dims=2, 1570 output_shape=[2, 3, 8, 9, 10, 5, 6, 7] 1571 # = params.shape[:2] + indices.shape[2:] + params.shape[3:] 1572 ), 1573 dict( 1574 params_shape=[2, 3, 4, 5, 6, 7], 1575 indices_shape=[2, 3, 4, 9, 10], 1576 batch_dims=3, 1577 output_shape=[2, 3, 4, 9, 10, 6, 7] 1578 # = params.shape[:3] + indices.shape[3:] + params.shape[4:] 1579 ), 1580 dict( 1581 params_shape=[2, 3, 4, 5, 6, 7], 1582 indices_shape=[2, 3, 4, 5, 10], 1583 batch_dims=4, 1584 output_shape=[2, 3, 4, 5, 10, 7] 1585 # = params.shape[:4] + indices.shape[4:] + params.shape[5:] 1586 ), 1587 ]) 1588 @test_util.run_in_graph_and_eager_modes 1589 def testGatherWithBatchDimsMatchesTensor(self, params_shape, indices_shape, 1590 batch_dims, output_shape): 1591 """Checks that gather with batch_dims returns the correct shape.""" 1592 # Generate a `params` tensor with the indicated shape. 1593 params_size = np.prod(params_shape) 1594 params = np.reshape(np.arange(params_size, dtype=np.int32), params_shape) 1595 1596 # Generate an `indices` tensor with the indicated shape, where each index 1597 # is within the appropriate range. 1598 indices_size = np.prod(indices_shape) 1599 indices = np.reshape(np.arange(indices_size, dtype=np.int32), indices_shape) 1600 indices = indices % params_shape[batch_dims] 1601 1602 var = resource_variable_ops.ResourceVariable(params, name="var0") 1603 with ops.control_dependencies([var.initializer]): 1604 expected = array_ops.gather( 1605 var.read_value(), indices, batch_dims=batch_dims) 1606 result = resource_variable_ops.resource_gather( 1607 var.handle, indices, dtype=var.dtype, batch_dims=batch_dims) 1608 1609 self.assertAllEqual(output_shape, result.shape.as_list()) 1610 self.assertAllEqual(expected, result) 1611 1612 @parameterized.parameters([ 1613 dict(dtype=dtypes.bool), 1614 dict(dtype=dtypes.int64), 1615 dict(dtype=dtypes.half), 1616 dict(dtype=dtypes.float32), 1617 dict(dtype=dtypes.double), 1618 ]) 1619 @test_util.run_gpu_only 1620 @test_util.run_in_graph_and_eager_modes 1621 def testGatherWithDTypes(self, dtype): 1622 if dtype == dtypes.bool: 1623 params = constant_op.constant([False, True, False, True]) 1624 expected = constant_op.constant([[False, True], [False, True]]) 1625 else: 1626 params = constant_op.constant([6, 7, 8, 9], dtype=dtype) 1627 expected = constant_op.constant([[8, 7], [6, 9]], dtype=dtype) 1628 indices = constant_op.constant([[2, 1], [0, 3]]) 1629 var = resource_variable_ops.ResourceVariable(params, name="var0") 1630 with ops.control_dependencies([var.initializer]): 1631 result = resource_variable_ops.resource_gather( 1632 var.handle, indices, dtype=dtype) 1633 self.assertAllEqual(expected, result) 1634 1635 @test_util.run_v2_only 1636 def testUninitializedVariableMemoryUsage(self): 1637 if test_util.is_gpu_available(): 1638 # TODO(allenl): Investigate possible GPU-specific memory leaks 1639 self.skipTest("Disabled when a GPU is available") 1640 # TODO(kkb): Python memory checker complains continuous `weakref` 1641 # allocations, investigate. 1642 if memory_checker.CppMemoryChecker is None: 1643 self.skipTest("Requires the C++ memory checker") 1644 1645 def _create_and_delete_variable(): 1646 resource_variable_ops.UninitializedVariable( 1647 shape=[100, 100], 1648 dtype=dtypes.float32) 1649 1650 _create_and_delete_variable() 1651 checker = memory_checker.CppMemoryChecker( 1652 "ResourceVariableOps.testUninitializedVariableMemoryUsage") 1653 for _ in range(2): 1654 _create_and_delete_variable() 1655 checker.record_snapshot() 1656 checker.stop() 1657 checker.report() 1658 checker.assert_no_leak_if_all_possibly_except_one() 1659 1660 @test_util.run_v2_only 1661 def testIterateVariable(self): 1662 v = variables.Variable([1., 2.]) 1663 self.assertAllClose([1., 2.], list(iter(v))) 1664 1665 1666if __name__ == "__main__": 1667 test.main() 1668