# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the distributed variables library."""

import copy
import os

from absl.testing import parameterized
from tensorflow.python.checkpoint import checkpoint as trackable_utils
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util as ds_test_util
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import save_context
from tensorflow.python.saved_model import save_options
from tensorflow.python.types import core


def _device_str(d):
  """Returns the device string "/device:GPU:<d>" for GPU index `d`."""
  return "/device:GPU:" + str(d)


def _nested_value(d):
  """Returns a nested tuple/list/dict of strings, each leaf suffixed with `d`.

  Args:
    d: A string that is appended to every leaf string of the structure.

  Returns:
    A tuple containing strings, a list and a dict, all built from `d`.
  """
  return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d)


def mirrored_and_tpu_strategy_combinations():
  """Returns Mirrored and TPU strategy combinations in graph and eager mode."""
  return combinations.combine(
      distribution=[
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call,
          strategy_combinations.tpu_strategy,
          strategy_combinations.tpu_strategy_packed_var,
      ],
      mode=["graph", "eager"])


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call,
            strategy_combinations.tpu_strategy,
            strategy_combinations.tpu_strategy_packed_var,
            strategy_combinations.tpu_strategy_spmd,
            strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
        ],
        synchronization=[
            variables_lib.VariableSynchronization.ON_READ,
            variables_lib.VariableSynchronization.ON_WRITE,
        ],
        aggregation=[
            variables_lib.VariableAggregation.MEAN,
            variables_lib.VariableAggregation.SUM,
            variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
        ],
        mode=["graph", "eager"],
        use_var_policy=[True, False]))
class DistributedVariableTest(test.TestCase, parameterized.TestCase):
  """Tests variables created under a distribution strategy scope.

  Every test runs for each combination of strategy, synchronization,
  aggregation, execution mode and `use_var_policy` listed in the decorator.
  """

  def testExtendsVariable(self, distribution, synchronization, aggregation):
    """A variable created under the strategy scope is a `Variable`."""
    with distribution.scope():
      v = variables_lib.Variable(
          1., synchronization=synchronization, aggregation=aggregation)
    self.assertIsInstance(v, variables_lib.Variable)

  def testCheckpointing(self, distribution, synchronization, aggregation, mode):
    """Saving, overwriting, then restoring recovers the saved value."""

    if (isinstance(distribution,
                   collective_all_reduce_strategy.CollectiveAllReduceStrategy)
        and mode == "graph"):
      self.skipTest("MWMS combinations tests do not work well in graph mode.")

    with distribution.scope():
      v = variables_lib.Variable(
          constant_op.constant([1., 2., 3., 4]),
          synchronization=synchronization,
          aggregation=aggregation)

    self.evaluate(v.initializer)
    before_save = self.evaluate(v.read_value())

    # Save the initial values ([1., 2., 3., 4.]) into a checkpoint.
    checkpoint = trackable_utils.Checkpoint(v=v)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    with self.test_session():
      save_path = checkpoint.save(prefix)

    # Assign inverted value.
    self.evaluate(v.assign(constant_op.constant([4., 3., 2., 1.])))
    after_assign = self.evaluate(v.read_value())
    self.assertNotAllClose(before_save, after_assign)

    # Restore from the checkpoint.
    with self.test_session():
      checkpoint.restore(save_path).assert_consumed().run_restore_ops()
    after_restore = self.evaluate(v)
    self.assertAllClose(before_save, after_restore)

  def testTraceback(self, distribution, synchronization, aggregation):
    """Re-creating an existing variable raises a descriptive ValueError."""
    if context.executing_eagerly():
      self.skipTest("does not apply to eager")
    with distribution.scope():
      variable_scope.get_variable(
          name="testVar",
          initializer=1.,
          use_resource=True,
          synchronization=synchronization,
          aggregation=aggregation)
      with self.assertRaisesRegex(ValueError,
                                  "Variable testVar already exists"):
        variable_scope.get_variable(
            name="testVar",
            initializer=1.,
            use_resource=True,
            synchronization=synchronization,
            aggregation=aggregation)

  def testSelectReplica(self, distribution, synchronization, aggregation):
    """`select_replica` returns the distributed variable itself."""
    with distribution.scope():
      v = variables_lib.Variable(
          1., synchronization=synchronization, aggregation=aggregation)
    self.assertIs(v, distribute_utils.select_replica(0, v))

  def testIsTensorLike(self, distribution, synchronization, aggregation):
    """The variable is a `core.Tensor` in cross-replica and replica context."""
    if isinstance(distribution.extended,
                  tpu_strategy.TPUExtended) and context.executing_eagerly():
      self.skipTest("TPU doesn't support pure eager")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
    # In cross replica context.
    self.assertIsInstance(v, core.Tensor)
    # In replica context.
    distribution.run(lambda v: self.assertIsInstance(v, core.Tensor), args=(v,))

  def testAssignReturnValueIsTensorLike(self, distribution, synchronization,
                                        aggregation):
    """The return values of assign/assign_add/assign_sub are `core.Tensor`s."""
    if isinstance(distribution.extended, tpu_strategy.TPUExtended):
      if context.executing_eagerly():
        self.skipTest("TPU doesn't support pure eager")
      else:
        self.skipTest("b/152076846")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)

    def assert_is_tensor_like(v):
      # We can't use Python literals because they are treated as
      # non-distributed values, which are not allowed when aggregation is SUM.
      # See `cross_device_ops.reduce_non_distributed_value`.
      delta = array_ops.identity(1.)
      self.assertIsInstance(v.assign(delta), core.Tensor)
      self.assertIsInstance(v.assign_sub(delta), core.Tensor)
      self.assertIsInstance(v.assign_add(delta), core.Tensor)

    # In cross replica context we return a PerReplica which is not Tensor like
    # all the time yet.
    if (synchronization == variables_lib.VariableSynchronization.ON_READ and
        aggregation != variables_lib.VariableAggregation.SUM):
      assert_is_tensor_like(v)

    # In replica context.
    distribution.run(assert_is_tensor_like, args=(v,))

  def testDeepCopy(self, distribution, synchronization, aggregation):
    """`copy.deepcopy` inside and outside the scope copies every component."""
    if not context.executing_eagerly():
      self.skipTest("deepcopy only supported in eager mode")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
      in_dist_copy = copy.deepcopy(v)

    out_dist_copy = copy.deepcopy(v)

    def assert_is_deep_copy(v1, v2):
      self.assertIsInstance(v2, type(v1))
      self.assertEqual(v1.aggregation, v2.aggregation)
      self.assertEqual(v1.distribute_strategy, v2.distribute_strategy)
      if isinstance(v1, ps_values.AggregatingVariable):
        self.assertIsInstance(v2.get(), type(v1.get()))
        self.assertNotEqual(id(v1.get()), id(v2.get()))
      else:
        if v1._policy:
          self.assertNotEqual(id(v1._policy), id(v2._policy))  # pylint: disable=protected-access
        else:
          self.assertEqual(id(v1._policy), id(v2._policy))  # pylint: disable=protected-access
        self.assertEqual(len(v1.values), len(v2.values))
        for (v1v, v2v) in zip(v1.values, v2.values):
          self.assertEqual(v1v.device, v2v.device)
          self.assertNotEqual(id(v1v), id(v2v))
        # The full component values are compared once; the comparison does not
        # depend on the loop variables above.
        self.assertAllEqual(
            self.evaluate(v1.values), self.evaluate(v2.values))

    self.evaluate(variables_lib.global_variables_initializer())
    if not isinstance(distribution.extended, tpu_strategy.TPUExtended):
      distribution.run(assert_is_deep_copy, args=(v, in_dist_copy))
      distribution.run(assert_is_deep_copy, args=(v, out_dist_copy))

  def testAssignSignature(self, distribution, synchronization, aggregation):
    """assign*() accept positional and keyword args like normal variables."""
    # This test verifies assign*() can be called in the same way as normal
    # variables.
    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)

      def assign():
        one = constant_op.constant(1.)
        v.assign(one, True, "assign", False)
        # TODO(b/154017756): SyncOnReadVariable.assign() doesn't support passing
        # value as a keyword argument.
        v.assign(one, use_locking=True, name="assign", read_value=False)
        v.assign_add(one, True, "assign", False)
        v.assign_add(one, use_locking=True, name="assign", read_value=False)
        v.assign_sub(one, True, "assign", False)
        v.assign_sub(one, use_locking=True, name="assign", read_value=False)
        # Return something for graph mode to fetch.
        return constant_op.constant(1)

      self.evaluate(variables_lib.global_variables_initializer())
      if not (synchronization == variables_lib.VariableSynchronization.ON_READ
              and aggregation == variables_lib.VariableAggregation.SUM):
        self.evaluate(distribution.experimental_local_results(assign()))
      if not (isinstance(distribution.extended, tpu_strategy.TPUExtended) and
              context.executing_eagerly()):
        self.evaluate(
            distribution.experimental_local_results(distribution.run(assign)))

  def testStrategyExtendedUpdate(self, distribution, synchronization,
                                 aggregation):
    """`extended.update` applies per-replica values to each component."""
    if len(distribution.extended.parameter_devices) != 2:
      self.skipTest("n/a: needs exactly two parameter devices")
    # NOTE: `VariableAggregation.NONE` is not among the parameterized
    # aggregations of this class, so every ON_WRITE case is skipped here.
    if (synchronization == variables_lib.VariableSynchronization.ON_WRITE and
        aggregation != variables_lib.VariableAggregation.NONE):
      self.skipTest("n/a: doesn't apply to ON_WRITE variable with aggregation")
    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
    value = values_lib.PerReplica([1., 2.])

    assign_fn = lambda var, value: var.assign(value)
    self.evaluate(distribution.extended.update(v, assign_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [1., 2.])

    assign_add_fn = lambda var, value: var.assign_add(value)
    self.evaluate(distribution.extended.update(v, assign_add_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [2., 4.])

    assign_sub_fn = lambda var, value: var.assign_sub(value)
    self.evaluate(distribution.extended.update(v, assign_sub_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [1., 2.])

    read_assign_fn = lambda var, value: var.assign_add(var.value() +
                                                       var.read_value())
    self.evaluate(
        distribution.extended.update(v, read_assign_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [3., 6.])

  def testSaveNonDistributed(self, distribution, synchronization, aggregation):
    """Traced under SaveContext, the variable behaves like its primary."""
    # This test verifies that the DistributedVariable behaves like its primary
    # variable when saving a non-distributed version of the model (the default).
    # The test asserts that the function traced under SaveContext has no device
    # annotations and only references the primary component of the variable.
    # Please avoid capturing other eager tensors in this test, to keep that
    # assertion simple.

    if isinstance(distribution.extended,
                  parameter_server_strategy.ParameterServerStrategyExtended):
      self.skipTest("b/148689177: AggregatingVariable doesn't "
                    "conform to Variable interface well")

    # tf.function requires the return value to be Tensors, which is not always
    # the case for properties and methods of Variable, so we simply discard the
    # return values.
    def _discard_return(f):
      f()
      return

    def _test(f, v):
      # This verifies that the function under SaveContext:
      #   - contains no device annotations.
      #   - only references the primary component of the variable.
      g = def_function.function(lambda: _discard_return(f))
      options = save_options.SaveOptions(
          experimental_variable_policy=save_options.VariablePolicy.NONE)
      with save_context.save_context(options):
        # The graph should contain no device.
        graph = g.get_concrete_function().graph
      for op in graph.get_operations():
        self.assertEqual(op.device, "", msg=str(op))
      # The function should only capture the primary variable. Note that it
      # may not have captures, e.g. v.aggregation.
      captures = list(graph.captures)
      self.assertLessEqual(len(captures), 1)
      if captures:
        self.assertIs(captures[0][0], v._primary.handle)

    def _assert(cond):
      return control_flow_ops.Assert(cond, [cond])

    with distribution.scope():
      # We use three variables for convenience reasons. They have no special
      # meaning.
      # - v is used whenever possible.
      # - w is used for scatter and indexing, which require the variable to be
      #   non-scalar.
      # - y is used when the dtype needs to be integer. Note that aggregation
      #   cannot be MEAN for integers.
      v = variables_lib.Variable(
          0.,
          synchronization=synchronization,
          aggregation=aggregation,
          trainable=True)
      w = variables_lib.Variable([0., 0., 0.],
                                 synchronization=synchronization,
                                 aggregation=aggregation,
                                 trainable=True)
      if aggregation != variables_lib.VariableAggregation.MEAN:
        y = variables_lib.Variable(
            0, synchronization=synchronization, aggregation=aggregation)

    # pylint: disable=g-long-lambda

    # tf.Variable properties.
    _test(lambda: self.assertEqual(v.aggregation, aggregation), v)
    _test(lambda: self.assertIs(v.constraint, None), v)
    # TODO(crccw): should we raise an error instead?
    _test(lambda: self.assertEqual(v.device, v._primary.device), v)
    _test(lambda: self.assertEqual(v.dtype, dtypes.float32), v)
    if not context.executing_eagerly():
      _test(lambda: self.assertIs(v.graph, v._primary.graph), v)
      _test(lambda: _assert(v.initial_value == 0), v)
    _test(lambda: self.assertIs(v.initializer, v._primary.initializer), v)
    _test(lambda: self.assertEqual(v.name, "Variable:0"), v)
    if not context.executing_eagerly():
      _test(lambda: self.assertIs(v.op, v._primary.op), v)
    _test(lambda: self.assertEqual(v.shape, tensor_shape.TensorShape(())), v)
    _test(lambda: self.assertEqual(v.synchronization, synchronization), v)
    _test(lambda: self.assertEqual(v.trainable, True), v)

    # tf.Variable methods.
    _test(lambda: check_ops.assert_equal_v2(v.assign(1.), 1.), v)
    _test(lambda: check_ops.assert_equal_v2(v.assign_add(1.), 2.), v)
    _test(lambda: check_ops.assert_equal_v2(v.assign_sub(1.), 1.), v)
    # TODO(b/148689177): Implement batch_scatter_update.
    # count_up_to() is skipped since it's deprecated.
    # eval() is skipped since it shouldn't called in a tf.function.
    # experimental_ref() is skipped since it's deprecated.
    # from_proto() is skipped since it shouldn't called in a tf.function.
    # TODO(b/148689177): Implement gather_nd.
    _test(
        lambda: check_ops.assert_equal_v2(v.get_shape(),
                                          tensor_shape.TensorShape(())), v)
    # initialized_value() is skipped since it shouldn't called in a tf.function.
    # load() is skipped since it shouldn't called in a tf.function.
    _test(lambda: check_ops.assert_equal_v2(v.read_value(), 1.), v)
    # ref() is skipped since it shouldn't called in a tf.function.
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_add(_make_index_slices(values=[1., 2.], indices=[0, 2])),
            [1., 0., 2.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_div(_make_index_slices(values=[4., 2.], indices=[0, 2])),
            [0.25, 0., 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_max(_make_index_slices(values=[1., 0.5], indices=[1, 2])),
            [0.25, 1., 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_min(_make_index_slices(values=[1., 0.5], indices=[0, 1])),
            [0.25, 0.5, 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_mul(_make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [0.5, 0.25, 1.]), w)
    # TODO(b/148689177): Implement scatter_nd_*
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_sub(_make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [-1.5, -0.25, 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_update(
                _make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [2., 0.5, 1.]), w)
    # set_shape() is skipped since ResourceVariable doesn't implement it.
    # to_proto() is skipped since it shouldn't called in a tf.function.
    _test(lambda: check_ops.assert_equal_v2(v.value(), 1.), v)

    # DistributedVariable should be treated as ResourceVariable, so it needs to
    # conform to ResourceVariable interface as well.
    _test(lambda: self.assertIs(v.handle, v._primary.handle), v)

    # Convert to tensor.
    _test(lambda: check_ops.assert_equal_v2(ops.convert_to_tensor(v), 1.), v)

    # Control dependency.
    def _with_control_dep():
      with ops.control_dependencies([v.assign(1.)]):
        return array_ops.identity(1)

    _test(_with_control_dep, v)

    # Operator overloads.
    _test(lambda: check_ops.assert_equal_v2(v.assign(7.), 7.), v)
    _test(lambda: check_ops.assert_equal_v2(v + 1., 8.), v)
    _test(lambda: check_ops.assert_equal_v2(3 + v, 10.), v)
    _test(lambda: check_ops.assert_equal_v2(v + v, 14.), v)
    _test(lambda: check_ops.assert_equal_v2(v - 2., 5.), v)
    _test(lambda: check_ops.assert_equal_v2(v - v, 0.), v)
    _test(lambda: check_ops.assert_equal_v2(v * 2., 14.), v)
    _test(lambda: check_ops.assert_equal_v2(3 * v, 21.), v)
    _test(lambda: check_ops.assert_equal_v2(v * v, 49.), v)
    _test(
        lambda: check_ops.assert_equal_v2(
            math_ops.cast(v / 2., dtypes.float32), 3.5), v)
    _test(
        lambda: check_ops.assert_equal_v2(
            math_ops.cast(14. / v, dtypes.float32), 2.), v)
    _test(lambda: _assert(v < 12.), v)
    _test(lambda: _assert(v <= 12.), v)
    _test(lambda: _assert(not v > 12.), v)
    _test(lambda: _assert(not v >= 12.), v)
    _test(lambda: _assert(not 12. < v), v)
    _test(lambda: _assert(not 12. <= v), v)
    _test(lambda: _assert(12. > v), v)
    _test(lambda: _assert(12. >= v), v)
    _test(lambda: check_ops.assert_near_v2(pow(v, 3.), 343.), v)
    _test(lambda: check_ops.assert_near_v2(pow(2., v), 128.), v)
    _test(lambda: check_ops.assert_equal_v2(abs(v), 7.), v)

    # Operator overloads that only works for integers.
    if aggregation != variables_lib.VariableAggregation.MEAN:
      _test(lambda: check_ops.assert_equal_v2(y.assign(7), 7), y)
      _test(lambda: check_ops.assert_equal_v2(y // 2, 3), y)
      _test(lambda: check_ops.assert_equal_v2(15 // y, 2), y)
      _test(lambda: check_ops.assert_equal_v2(y % 2, 1), y)
      _test(lambda: check_ops.assert_equal_v2(16 % y, 2), y)
      _test(lambda: check_ops.assert_equal_v2(y & 3, 3), y)
      _test(lambda: check_ops.assert_equal_v2(3 & y, 3), y)
      _test(lambda: check_ops.assert_equal_v2(y | 8, 15), y)
      _test(lambda: check_ops.assert_equal_v2(16 | y, 23), y)
      _test(lambda: check_ops.assert_equal_v2(y ^ 3, 4), y)
      _test(lambda: check_ops.assert_equal_v2(11 ^ y, 12), y)
      _test(lambda: check_ops.assert_equal_v2(-y, -7), y)
      _test(lambda: check_ops.assert_equal_v2(~y, ~7), y)

    # Index.
    if isinstance(distribution.extended, tpu_strategy.TPUExtended):
      # TODO(b/161572567): slice assignment doesn't work for TPU.
      _test(lambda: check_ops.assert_equal_v2(w[0], 2.), w)
    else:
      _test(lambda: check_ops.assert_equal_v2(w[0].assign(1.), [1., 0.5, 1.]),
            w)
      _test(lambda: check_ops.assert_equal_v2(w[0], 1.), w)

    # pylint: enable=g-long-lambda

  def testUnsaveable(self, distribution, synchronization, aggregation, mode):
    """Functions that update or aggregate-read the variable are unsaveable."""
    if isinstance(distribution.extended,
                  parameter_server_strategy.ParameterServerStrategyExtended):
      self.skipTest("n/a: not appliable to AggregatingVariable")
    if (isinstance(distribution,
                   collective_all_reduce_strategy.CollectiveAllReduceStrategy)
        and mode == "graph"):
      self.skipTest("MWMS combinations tests do not work well in graph mode.")
    # Past this point `_use_merge_call()` is known to be True.
    if not distribution.extended._use_merge_call():
      self.skipTest("Unsupported combination.")
    with distribution.scope():
      v = variables_lib.Variable([1., 1.],
                                 synchronization=synchronization,
                                 aggregation=aggregation)

    with self.cached_session():
      self.evaluate(variables_lib.global_variables_initializer())

    export_dir = self.get_temp_dir()

    def _assert_unsaveable(f):
      # Ignore if it cannot be traced. Certain combinations are not supported
      # yet or not allowed.
      try:
        f = def_function.function(f).get_concrete_function()
      except (NotImplementedError, ValueError):
        return
      with self.assertRaisesRegex(ValueError, "f_with_input_signature"):
        save.save(v, export_dir, signatures=f)

    _assert_unsaveable(lambda: v.assign(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.assign_add(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.assign_sub(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.scatter_add(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_sub(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_mul(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_div(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_min(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_max(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_update(_make_index_slices([1.], [0])))
    # Reading an ON_READ variable should be unsaveable if either:
    # 1) CollectiveAllReduceStrategy, and aggregation is MEAN/SUM.
    # 2) aggregation is SUM.
    # (The no-merge-call case was already skipped above.)
    if (synchronization == variables_lib.VariableSynchronization.ON_READ and
        (aggregation == variables_lib.VariableAggregation.SUM or
         (isinstance(distribution.extended,
                     collective_all_reduce_strategy.CollectiveAllReduceExtended)
          and aggregation == variables_lib.VariableAggregation.MEAN))):
      _assert_unsaveable(v.read_value)
      _assert_unsaveable(v.value)
      _assert_unsaveable(lambda: ops.convert_to_tensor(v))
    else:
      # Otherwise reading a variable should be saveable.

      @def_function.function
      def f():
        v.read_value()
        v.value()
        return ops.convert_to_tensor(v)

      with self.cached_session():
        save.save(v, export_dir, signatures=f.get_concrete_function())


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.tpu_strategy,
        ],
        mode=["eager"]))
class PackedDistributedVariableTest(test.TestCase, parameterized.TestCase):
  """Tests packed distributed variables in eager mode."""

  def testPackedVariable(self, distribution):
    """Variables are packed only when packing is enabled on the strategy."""
    with distribution.scope():
      v0 = variables_lib.Variable(0.)
    self.assertIsNone(v0._packed_var)

    distribution._enable_packed_variable_in_eager_mode = True
    with distribution.scope():
      v1 = variables_lib.Variable(0)
      self.assertIsInstance(v1._packed_var, packed.PackedDistributedVariable)

    devices = v1._devices
    # Give every non-first replica a distinct value equal to its replica id.
    for i in range(1, len(devices)):
      with distribute_lib.ReplicaContext(distribution, i):
        v1.assign(i)
    # In cross-replica context the first component is returned.
    val = v1._get()
    self.assertIsInstance(val, packed.PackedVarAndDevice)
    self.assertEqual(val.device, devices[0])
    self.assertEqual(self.evaluate(val.read_value()), 0)
    # In replica context `i`, the component on `devices[i]` is returned.
    for i, device in enumerate(devices):
      with distribute_lib.ReplicaContext(distribution, i):
        val = v1._get()
        self.assertIsInstance(val, packed.PackedVarAndDevice)
        self.assertEqual(val.device, device)
        self.assertEqual(self.evaluate(val.read_value()), i)

  def testIgnorePackedVariableInSaveContext(self, distribution):
    """Under a SaveContext the packed variable is not exposed."""
    distribution._enable_packed_variable_in_eager_mode = True
    with distribution.scope():
      v = variables_lib.Variable(0)
      self.assertIsInstance(v._packed_variable,
                            packed.PackedDistributedVariable)

    options = save_options.SaveOptions()
    with save_context.save_context(options):
      self.assertIsNone(v._packed_variable)


def _make_index_slices(values, indices, dense_shape=None):
  """Builds an `IndexedSlices` whose components are passed through identity.

  Args:
    values: Values of the slices.
    indices: Indices of the slices.
    dense_shape: Optional dense shape. It is left as None when not given.

  Returns:
    An `indexed_slices.IndexedSlices`.
  """
  # Compare against None instead of truth-testing: truth-testing a
  # multi-element tensor `dense_shape` raises.
  if dense_shape is not None:
    dense_shape = array_ops.identity(dense_shape)
  return indexed_slices.IndexedSlices(
      array_ops.identity(values), array_ops.identity(indices), dense_shape)


if __name__ == "__main__":
  ds_test_util.main()