# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the distributed variables library."""

import copy
import os

from absl.testing import parameterized
from tensorflow.python.checkpoint import checkpoint as trackable_utils
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util as ds_test_util
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import save_context
from tensorflow.python.saved_model import save_options
from tensorflow.python.types import core


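# Returns a GPU device string, e.g. _device_str(1) == "/device:GPU:1".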
def _device_str(d):
  return "/device:GPU:" + str(d)


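# Builds a nested structure (tuple/list/dict) of strings tagged with `d`.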
def _nested_value(d):
  return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d)


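# Mirrored and TPU strategy combinations, in both graph and eager mode.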
def mirrored_and_tpu_strategy_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call,
          strategy_combinations.tpu_strategy,
          strategy_combinations.tpu_strategy_packed_var,
      ],
      mode=["graph", "eager"])


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call,
            strategy_combinations.tpu_strategy,
            strategy_combinations.tpu_strategy_packed_var,
            strategy_combinations.tpu_strategy_spmd,
            strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
        ],
        synchronization=[
            variables_lib.VariableSynchronization.ON_READ,
            variables_lib.VariableSynchronization.ON_WRITE,
        ],
        aggregation=[
            variables_lib.VariableAggregation.MEAN,
            variables_lib.VariableAggregation.SUM,
            variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
        ],
        mode=["graph", "eager"],
        use_var_policy=[True, False]))
class DistributedVariableTest(test.TestCase, parameterized.TestCase):
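  """Tests for DistributedVariable across the combinations declared above."""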

  def testExtendsVariable(self, distribution, synchronization, aggregation):
    with distribution.scope():
      v = variables_lib.Variable(
          1., synchronization=synchronization, aggregation=aggregation)
    self.assertIsInstance(v, variables_lib.Variable)

  def testCheckpointing(self, distribution, synchronization, aggregation, mode):

    if (isinstance(distribution,
                   collective_all_reduce_strategy.CollectiveAllReduceStrategy)
        and mode == "graph"):
      self.skipTest("MWMS combinations tests do not work well in graph mode.")

    with distribution.scope():
      v = variables_lib.Variable(
          constant_op.constant([1., 2., 3., 4]),
          synchronization=synchronization,
          aggregation=aggregation)

    self.evaluate(v.initializer)
    before_save = self.evaluate(v.read_value())

    # Save the initial value into a checkpoint.
    checkpoint = trackable_utils.Checkpoint(v=v)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    with self.test_session():
      save_path = checkpoint.save(prefix)

    # Assign the reversed value.
    self.evaluate(v.assign(constant_op.constant([4., 3., 2., 1.])))
    after_assign = self.evaluate(v.read_value())
    self.assertNotAllClose(before_save, after_assign)

    # Restore from the checkpoint.
    with self.test_session():
      checkpoint.restore(save_path).assert_consumed().run_restore_ops()
    after_restore = self.evaluate(v)
    self.assertAllClose(before_save, after_restore)

  def testTraceback(self, distribution, synchronization, aggregation):
    if context.executing_eagerly():
      self.skipTest("does not apply to eager")
    with distribution.scope():
      variable_scope.get_variable(
          name="testVar",
          initializer=1.,
          use_resource=True,
          synchronization=synchronization,
          aggregation=aggregation)
      with self.assertRaisesRegex(ValueError,
                                  "Variable testVar already exists"):
        variable_scope.get_variable(
            name="testVar",
            initializer=1.,
            use_resource=True,
            synchronization=synchronization,
            aggregation=aggregation)

  def testSelectReplica(self, distribution, synchronization, aggregation):
    with distribution.scope():
      v = variables_lib.Variable(
          1., synchronization=synchronization, aggregation=aggregation)
    self.assertIs(v, distribute_utils.select_replica(0, v))

  def testIsTensorLike(self, distribution, synchronization, aggregation):
    if isinstance(distribution.extended,
                  tpu_strategy.TPUExtended) and context.executing_eagerly():
      self.skipTest("TPU doesn't support pure eager")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
    # In cross replica context.
    self.assertIsInstance(v, core.Tensor)
    # In replica context.
    distribution.run(lambda v: self.assertIsInstance(v, core.Tensor), args=(v,))

  def testAssignReturnValueIsTensorLike(self, distribution, synchronization,
                                        aggregation):
    if isinstance(distribution.extended, tpu_strategy.TPUExtended):
      if context.executing_eagerly():
        self.skipTest("TPU doesn't support pure eager")
      else:
        self.skipTest("b/152076846")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)

    def assert_is_tensor_like(v):
      # We can't use Python literals because they are treated as non-distributed
      # values, which is not allowed when aggregation is SUM. See
      # `cross_device_ops.reduce_non_distributed_value`.
      delta = array_ops.identity(1.)
      self.assertIsInstance(v.assign(delta), core.Tensor)
      self.assertIsInstance(v.assign_sub(delta), core.Tensor)
      self.assertIsInstance(v.assign_add(delta), core.Tensor)

    # In cross-replica context assign() may return a PerReplica, which is not
    # always tensor-like yet.
    if (synchronization == variables_lib.VariableSynchronization.ON_READ and
        aggregation != variables_lib.VariableAggregation.SUM):
      assert_is_tensor_like(v)

    # In replica context.
    distribution.run(assert_is_tensor_like, args=(v,))

  def testDeepCopy(self, distribution, synchronization, aggregation):
    if not context.executing_eagerly():
      self.skipTest("deepcopy only supported in eager mode")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
      in_dist_copy = copy.deepcopy(v)

    out_dist_copy = copy.deepcopy(v)

    def assert_is_deep_copy(v1, v2):
      self.assertIsInstance(v2, type(v1))
      self.assertEqual(v1.aggregation, v2.aggregation)
      self.assertEqual(v1.distribute_strategy, v2.distribute_strategy)
      if isinstance(v1, ps_values.AggregatingVariable):
        self.assertIsInstance(v2.get(), type(v1.get()))
        self.assertNotEqual(id(v1.get()), id(v2.get()))
      else:
        if v1._policy:
          self.assertNotEqual(id(v1._policy), id(v2._policy))  # pylint: disable=protected-access
        else:
          self.assertEqual(id(v1._policy), id(v2._policy))  # pylint: disable=protected-access
        self.assertEqual(len(v1.values), len(v2.values))
        for (v1v, v2v) in zip(v1.values, v2.values):
          self.assertEqual(v1v.device, v2v.device)
          self.assertNotEqual(id(v1v), id(v2v))
          self.assertAllEqual(
              self.evaluate(v1.values), self.evaluate(v2.values))

    self.evaluate(variables_lib.global_variables_initializer())
    if not isinstance(distribution.extended, tpu_strategy.TPUExtended):
      distribution.run(assert_is_deep_copy, args=(v, in_dist_copy))
      distribution.run(assert_is_deep_copy, args=(v, out_dist_copy))

  def testAssignSignature(self, distribution, synchronization, aggregation):
    # This test verifies that assign*() can be called in the same way as on
    # normal variables.
    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)

      def assign():
        one = constant_op.constant(1.)
        v.assign(one, True, "assign", False)
        # TODO(b/154017756): SyncOnReadVariable.assign() doesn't support passing
        # value as a keyword argument.
        v.assign(one, use_locking=True, name="assign", read_value=False)
        v.assign_add(one, True, "assign", False)
        v.assign_add(one, use_locking=True, name="assign", read_value=False)
        v.assign_sub(one, True, "assign", False)
        v.assign_sub(one, use_locking=True, name="assign", read_value=False)
        # Return something for graph mode to fetch.
        return constant_op.constant(1)

      self.evaluate(variables_lib.global_variables_initializer())
      if not (synchronization == variables_lib.VariableSynchronization.ON_READ
              and aggregation == variables_lib.VariableAggregation.SUM):
        self.evaluate(distribution.experimental_local_results(assign()))
      if not (isinstance(distribution.extended, tpu_strategy.TPUExtended) and
              context.executing_eagerly()):
        self.evaluate(
            distribution.experimental_local_results(distribution.run(assign)))

  def testStrategyExtendedUpdate(self, distribution, synchronization,
                                 aggregation):
    if len(distribution.extended.parameter_devices) != 2:
      self.skipTest("n/a: needs exactly two parameter devices")
    if (synchronization == variables_lib.VariableSynchronization.ON_WRITE and
        aggregation != variables_lib.VariableAggregation.NONE):
      self.skipTest("n/a: doesn't apply to ON_WRITE variable with aggregation")
    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
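    # `value` holds one component per parameter device; extended.update()
    # applies the update fn on each device with the matching component.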
    value = values_lib.PerReplica([1., 2.])

    assign_fn = lambda var, value: var.assign(value)
    self.evaluate(distribution.extended.update(v, assign_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [1., 2.])

    assign_add_fn = lambda var, value: var.assign_add(value)
    self.evaluate(distribution.extended.update(v, assign_add_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [2., 4.])

    assign_sub_fn = lambda var, value: var.assign_sub(value)
    self.evaluate(distribution.extended.update(v, assign_sub_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [1., 2.])

    read_assign_fn = lambda var, value: var.assign_add(
        var.value() + var.read_value())
    self.evaluate(
        distribution.extended.update(v, read_assign_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [3., 6.])

  def testSaveNonDistributed(self, distribution, synchronization, aggregation):
    # This test verifies that a DistributedVariable behaves like its primary
    # variable when saving a non-distributed version of the model (the default).
    # The test asserts that the function traced under SaveContext has no device
    # annotations and only references the primary component of the variable.
    # Note: avoid capturing other eager tensors in this test to keep the
    # assertion simple.

    if isinstance(distribution.extended,
                  parameter_server_strategy.ParameterServerStrategyExtended):
      self.skipTest("b/148689177: AggregatingVariable doesn't "
                    "conform to Variable interface well")

    # tf.function requires the return value to be Tensors, which is not always
    # the case for properties and methods of Variable, so we simply discard the
    # return values.
    def _discard_return(f):
      f()
      return

    def _test(f, v):
      # This verifies that the function under SaveContext:
      #   - contains no device annotations.
      #   - only references the primary component of the variable.
      g = def_function.function(lambda: _discard_return(f))
      options = save_options.SaveOptions(
          experimental_variable_policy=save_options.VariablePolicy.NONE)
      with save_context.save_context(options):
        # The graph should contain no device.
        graph = g.get_concrete_function().graph
      for op in graph.get_operations():
        self.assertEqual(op.device, "", msg=str(op))
      # The function should only capture the primary variable. Note that it
      # may not have captures, e.g. v.aggregation.
      captures = list(graph.captures)
      self.assertLessEqual(len(captures), 1)
      if graph.captures:
        self.assertIs(captures[0][0], v._primary.handle)

    def _assert(cond):
      return control_flow_ops.Assert(cond, [cond])

    with distribution.scope():
      # We use three variables for convenience. They have no special meaning.
      # - v is used whenever possible.
      # - w is used for scatter and gather, which require the variable to be
      #   non-scalar.
      # - y is used when the dtype needs to be an integer. Note that
      #   aggregation cannot be MEAN for integers.
      v = variables_lib.Variable(
          0.,
          synchronization=synchronization,
          aggregation=aggregation,
          trainable=True)
      w = variables_lib.Variable([0., 0., 0.],
                                 synchronization=synchronization,
                                 aggregation=aggregation,
                                 trainable=True)
      if aggregation != variables_lib.VariableAggregation.MEAN:
        y = variables_lib.Variable(
            0, synchronization=synchronization, aggregation=aggregation)

    # pylint: disable=g-long-lambda

    # tf.Variable properties.
    _test(lambda: self.assertEqual(v.aggregation, aggregation), v)
    _test(lambda: self.assertIs(v.constraint, None), v)
    # TODO(crccw): should we raise an error instead?
    _test(lambda: self.assertEqual(v.device, v._primary.device), v)
    _test(lambda: self.assertEqual(v.dtype, dtypes.float32), v)
    if not context.executing_eagerly():
      _test(lambda: self.assertIs(v.graph, v._primary.graph), v)
    if not context.executing_eagerly():
      _test(lambda: _assert(v.initial_value == 0), v)
    _test(lambda: self.assertIs(v.initializer, v._primary.initializer), v)
    _test(lambda: self.assertEqual(v.name, "Variable:0"), v)
    if not context.executing_eagerly():
      _test(lambda: self.assertIs(v.op, v._primary.op), v)
    _test(lambda: self.assertEqual(v.shape, tensor_shape.TensorShape(())), v)
    _test(lambda: self.assertEqual(v.synchronization, synchronization), v)
    _test(lambda: self.assertEqual(v.trainable, True), v)

    # tf.Variable methods.
    _test(lambda: check_ops.assert_equal_v2(v.assign(1.), 1.), v)
    _test(lambda: check_ops.assert_equal_v2(v.assign_add(1.), 2.), v)
    _test(lambda: check_ops.assert_equal_v2(v.assign_sub(1.), 1.), v)
    # TODO(b/148689177): Implement batch_scatter_update.
    # count_up_to() is skipped since it's deprecated.
    # eval() is skipped since it shouldn't be called in a tf.function.
    # experimental_ref() is skipped since it's deprecated.
    # from_proto() is skipped since it shouldn't be called in a tf.function.
    # TODO(b/148689177): Implement gather_nd.
    _test(
        lambda: check_ops.assert_equal_v2(v.get_shape(),
                                          tensor_shape.TensorShape(())), v)
    # initialized_value() is skipped since it shouldn't be called in a
    # tf.function.
    # load() is skipped since it shouldn't be called in a tf.function.
    _test(lambda: check_ops.assert_equal_v2(v.read_value(), 1.), v)
    # ref() is skipped since it shouldn't be called in a tf.function.
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_add(_make_index_slices(values=[1., 2.], indices=[0, 2])),
            [1., 0., 2.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_div(_make_index_slices(values=[4., 2.], indices=[0, 2])),
            [0.25, 0., 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_max(_make_index_slices(values=[1., 0.5], indices=[1, 2])),
            [0.25, 1., 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_min(_make_index_slices(values=[1., 0.5], indices=[0, 1])),
            [0.25, 0.5, 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_mul(_make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [0.5, 0.25, 1.]), w)
    # TODO(b/148689177): Implement scatter_nd_*
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_sub(_make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [-1.5, -0.25, 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_update(
                _make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [2., 0.5, 1.]), w)
    # set_shape() is skipped since ResourceVariable doesn't implement it.
    # to_proto() is skipped since it shouldn't be called in a tf.function.
    _test(lambda: check_ops.assert_equal_v2(v.value(), 1.), v)

    # A DistributedVariable should be treated as a ResourceVariable, so it
    # needs to conform to the ResourceVariable interface as well.
    _test(lambda: self.assertIs(v.handle, v._primary.handle), v)

    # Convert to tensor.
    _test(lambda: check_ops.assert_equal_v2(ops.convert_to_tensor(v), 1.), v)

    # Control dependency.
    def _with_control_dep():
      with ops.control_dependencies([v.assign(1.)]):
        return array_ops.identity(1)

    _test(_with_control_dep, v)

    # Operator overloads.
    _test(lambda: check_ops.assert_equal_v2(v.assign(7.), 7.), v)
    _test(lambda: check_ops.assert_equal_v2(v + 1., 8.), v)
    _test(lambda: check_ops.assert_equal_v2(3 + v, 10.), v)
    _test(lambda: check_ops.assert_equal_v2(v + v, 14.), v)
    _test(lambda: check_ops.assert_equal_v2(v - 2., 5.), v)
    _test(lambda: check_ops.assert_equal_v2(v - v, 0.), v)
    _test(lambda: check_ops.assert_equal_v2(v * 2., 14.), v)
    _test(lambda: check_ops.assert_equal_v2(3 * v, 21.), v)
    _test(lambda: check_ops.assert_equal_v2(v * v, 49.), v)
    _test(
        lambda: check_ops.assert_equal_v2(
            math_ops.cast(v / 2., dtypes.float32), 3.5), v)
    _test(
        lambda: check_ops.assert_equal_v2(
            math_ops.cast(14. / v, dtypes.float32), 2.), v)
    _test(lambda: _assert(v < 12.), v)
    _test(lambda: _assert(v <= 12.), v)
    _test(lambda: _assert(not v > 12.), v)
    _test(lambda: _assert(not v >= 12.), v)
    _test(lambda: _assert(not 12. < v), v)
    _test(lambda: _assert(not 12. <= v), v)
    _test(lambda: _assert(12. > v), v)
    _test(lambda: _assert(12. >= v), v)
    _test(lambda: check_ops.assert_near_v2(pow(v, 3.), 343.), v)
    _test(lambda: check_ops.assert_near_v2(pow(2., v), 128.), v)
    _test(lambda: check_ops.assert_equal_v2(abs(v), 7.), v)

    # Operator overloads that only work for integers.
    if aggregation != variables_lib.VariableAggregation.MEAN:
      _test(lambda: check_ops.assert_equal_v2(y.assign(7), 7), y)
      _test(lambda: check_ops.assert_equal_v2(y // 2, 3), y)
      _test(lambda: check_ops.assert_equal_v2(15 // y, 2), y)
      _test(lambda: check_ops.assert_equal_v2(y % 2, 1), y)
      _test(lambda: check_ops.assert_equal_v2(16 % y, 2), y)
      _test(lambda: check_ops.assert_equal_v2(y & 3, 3), y)
      _test(lambda: check_ops.assert_equal_v2(3 & y, 3), y)
      _test(lambda: check_ops.assert_equal_v2(y | 8, 15), y)
      _test(lambda: check_ops.assert_equal_v2(16 | y, 23), y)
      _test(lambda: check_ops.assert_equal_v2(y ^ 3, 4), y)
      _test(lambda: check_ops.assert_equal_v2(11 ^ y, 12), y)
      _test(lambda: check_ops.assert_equal_v2(-y, -7), y)
      _test(lambda: check_ops.assert_equal_v2(~y, ~7), y)

    # Index.
    if isinstance(distribution.extended, tpu_strategy.TPUExtended):
      # TODO(b/161572567): slice assignment doesn't work for TPU.
      _test(lambda: check_ops.assert_equal_v2(w[0], 2.), w)
    else:
      _test(lambda: check_ops.assert_equal_v2(w[0].assign(1.), [1., 0.5, 1.]),
            w)
      _test(lambda: check_ops.assert_equal_v2(w[0], 1.), w)

    # pylint: enable=g-long-lambda

  def testUnsaveable(self, distribution, synchronization, aggregation, mode):
    if isinstance(distribution.extended,
                  parameter_server_strategy.ParameterServerStrategyExtended):
      self.skipTest("n/a: not applicable to AggregatingVariable")
    if (isinstance(distribution,
                   collective_all_reduce_strategy.CollectiveAllReduceStrategy)
        and mode == "graph"):
      self.skipTest("MWMS combinations tests do not work well in graph mode.")
    if not distribution.extended._use_merge_call():
      self.skipTest("Unsupported combination.")
    with distribution.scope():
      v = variables_lib.Variable([1., 1.],
                                 synchronization=synchronization,
                                 aggregation=aggregation)

    with self.cached_session():
      self.evaluate(variables_lib.global_variables_initializer())

    export_dir = self.get_temp_dir()

    def _assert_unsaveable(f):
      # Ignore if it cannot be traced. Certain combinations are not supported
      # yet or not allowed.
      try:
        f = def_function.function(f).get_concrete_function()
      except (NotImplementedError, ValueError):
        return
      with self.assertRaisesRegex(ValueError, "f_with_input_signature"):
        save.save(v, export_dir, signatures=f)

    _assert_unsaveable(lambda: v.assign(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.assign_add(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.assign_sub(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.scatter_add(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_sub(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_mul(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_div(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_min(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_max(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_update(_make_index_slices([1.], [0])))
    # Reading an ON_READ variable should be unsaveable if any of the following
    # holds:
    # 1) aggregation is SUM.
    # 2) the strategy does not use merge_call.
    # 3) CollectiveAllReduceStrategy, and aggregation is MEAN.
    if (synchronization == variables_lib.VariableSynchronization.ON_READ and
        (aggregation == variables_lib.VariableAggregation.SUM or
         (not distribution.extended._use_merge_call()) or
         (isinstance(distribution.extended,
                     collective_all_reduce_strategy.CollectiveAllReduceExtended)
          and aggregation == variables_lib.VariableAggregation.MEAN))):
      _assert_unsaveable(v.read_value)
      _assert_unsaveable(v.value)
      _assert_unsaveable(lambda: ops.convert_to_tensor(v))
    else:
      # Otherwise reading a variable should be saveable.

      @def_function.function
      def f():
        v.read_value()
        v.value()
        return ops.convert_to_tensor(v)

      with self.cached_session():
        save.save(v, export_dir, signatures=f.get_concrete_function())


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.tpu_strategy,
        ],
        mode=["eager"]))
class PackedDistributedVariableTest(test.TestCase, parameterized.TestCase):
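  """Tests for packed distributed variables (eager mode only)."""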

  def testPackedVariable(self, distribution):
    with distribution.scope():
      v0 = variables_lib.Variable(0.)
    self.assertIsNone(v0._packed_var)

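    # Packing is opt-in: with the flag enabled, variables created under the
    # scope also carry a PackedDistributedVariable that bundles the
    # per-device components.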
    distribution._enable_packed_variable_in_eager_mode = True
    with distribution.scope():
      v1 = variables_lib.Variable(0)
      self.assertIsInstance(v1._packed_var, packed.PackedDistributedVariable)

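    # Assign a distinct value on every replica except replica 0. Reads in
    # cross-replica context resolve to the primary (device 0) component,
    # while reads inside each ReplicaContext resolve to that replica's
    # device and value.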
    devices = v1._devices
    for i in range(1, len(devices)):
      with distribute_lib.ReplicaContext(distribution, i):
        v1.assign(i)
    val = v1._get()
    self.assertIsInstance(val, packed.PackedVarAndDevice)
    self.assertEqual(val.device, devices[0])
    self.assertEqual(self.evaluate(val.read_value()), 0)
    for i in range(0, len(devices)):
      with distribute_lib.ReplicaContext(distribution, i):
        val = v1._get()
        self.assertIsInstance(val, packed.PackedVarAndDevice)
        self.assertEqual(val.device, devices[i])
        self.assertEqual(self.evaluate(val.read_value()), i)

  def testIgnorePackedVariableInSaveContext(self, distribution):
    distribution._enable_packed_variable_in_eager_mode = True
    with distribution.scope():
      v = variables_lib.Variable(0)
      self.assertIsInstance(v._packed_variable,
                            packed.PackedDistributedVariable)

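    # Under a SaveContext the packed variable is not exposed (the property
    # returns None), so saving does not use the packed handle.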
    options = save_options.SaveOptions()
    with save_context.save_context(options):
      self.assertIsNone(v._packed_variable)


def _make_index_slices(values, indices, dense_shape=None):
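  """Builds an IndexedSlices from `values` and `indices` (as tensors)."""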
  if dense_shape:
    dense_shape = array_ops.identity(dense_shape)
  return indexed_slices.IndexedSlices(
      array_ops.identity(values), array_ops.identity(indices), dense_shape)


if __name__ == "__main__":
  ds_test_util.main()