# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the distributed values library."""

import copy
import os

from absl.testing import parameterized
import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util as ds_test_util
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training import saver as saver_lib


def _device_str(d):
  return "/device:GPU:" + str(d)


def _nested_value(d):
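  """Returns a nested structure whose string leaves are tagged with `d`."""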
  return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d)


def mirrored_and_tpu_strategy_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call,
          strategy_combinations.tpu_strategy,
          strategy_combinations.tpu_strategy_packed_var,
          strategy_combinations.tpu_strategy_spmd,
      ],
      mode=["graph", "eager"])


class DistributedValuesTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueFromTensor(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    single_value = constant_op.constant(1)
    def value_fn(ctx):
      del ctx
      return single_value

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    self.assertAllEqual(
        ds_test_util.gather(distribution, distributed_values),
        constant_op.constant(1., shape=(distribution.num_replicas_in_sync)))

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueSingleNumpyArrayConstant(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    array_value = np.array([1., 2., 3.])
    def value_fn(ctx):
      del ctx
      return array_value

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    self.assertAllEqual(
        ds_test_util.gather(distribution, distributed_values).numpy(),
        [[1., 2., 3.]] * distribution.num_replicas_in_sync)

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueTupleConstant(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    tuple_value = (1., 2., 3.)
    def value_fn(ctx):
      del ctx
      return tuple_value
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)

    # Expected output for 2 replicas:
    # ([1.0, 1.0], [2.0, 2.0], [3.0, 3.0])
    expected = tuple([v for i in range(distribution.num_replicas_in_sync)]
                     for v in tuple_value)
    self.assertAllEqual(distributed_values, expected)

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueNestedStructurePerReplica(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    tuple_value = (1., 2., 3.)
    def value_fn(ctx):
      per_replica = []
      for val in tuple_value:
        per_replica.append(val * ctx.replica_id_in_sync_group)
      return tuple(per_replica)
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)

    # Expected output for 2 replicas:
    # ([0.0, 1.0], [0.0, 2.0], [0.0, 3.0])
    expected = tuple([v * i for i in range(distribution.num_replicas_in_sync)]
                     for v in tuple_value)
    self.assertAllEqual(distributed_values, expected)

  # NOTE(priyag): Cannot test this with MultiWorkerMirroredStrategy because
  # collective ops do not support SparseTensors.
  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies_minus_default,
          mode=["eager"]
      ))
  def testMakeDistributedValueSparseTensor(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    def value_fn(ctx):
      del ctx
      return sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    local_results = distribution.experimental_local_results(distributed_values)
    for i in range(distribution.num_replicas_in_sync):
      self.assertAllEqual(
          sparse_ops.sparse_tensor_to_dense(local_results[i]),
          [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]])

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueExtractFromArray(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    multiple_values = range(distribution.num_replicas_in_sync)
    def value_fn(ctx):
      return multiple_values[ctx.replica_id_in_sync_group]
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)
    expected = range(distribution.num_replicas_in_sync)
    self.assertAllEqual(distributed_values, expected)

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueAndRun(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    @def_function.function
    def run():
      multiple_values = range(distribution.num_replicas_in_sync)
      def value_fn(ctx):
        return multiple_values[ctx.replica_id_in_sync_group]
      distributed_values = (
          distribution.experimental_distribute_values_from_function(value_fn))

      def computation(x):
        return math_ops.square(x)

      outputs = ds_test_util.gather(
          distribution,
          distribution.run(computation, args=(distributed_values,)))
      return outputs

    results = run()

    expected = [i**2 for i in range(distribution.num_replicas_in_sync)]
    self.assertAllEqual(results, expected)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ] + strategy_combinations.multiworker_strategies,
          mode=["eager"]))
  def testMakeDistributedValueDefaultDevicePlacement(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    def value_fn(ctx):
      del ctx
      return constant_op.constant(1.0)
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    default_device = array_ops.identity(constant_op.constant(1.0)).device
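    # value_fn does not pin its output to a device, so every per-replica
    # component should land on the same default device as an identity op.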
    for i in range(len(distribution.extended.worker_devices)):
      self.assertAllEqual(distributed_values._values[i].device, default_device)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ] + strategy_combinations.multiworker_strategies,
          mode=["eager"],
          op_type=[constant_op.constant, array_ops.identity]))
  def testMakeDistributedValueExplicitDevicePlacement(self, distribution,
                                                      op_type):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    worker_devices = distribution.extended.worker_devices
    def value_fn(ctx):
      # In a multi-client setup, worker_devices contains only the devices on
      # the local worker.
      worker_device_id = ctx.replica_id_in_sync_group % len(worker_devices)
      with ops.device(worker_devices[worker_device_id]):
        return op_type(1.0)

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    for i in range(len(distribution.extended.worker_devices)):
      self.assertAllEqual(distributed_values._values[i].device,
                          worker_devices[i])


class PerReplicaTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ] + strategy_combinations.multiworker_strategies,
          mode=["eager"]))
  def testUsePerReplicaInvalidContextGivesError(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    multiple_values = range(distribution.num_replicas_in_sync)
    def value_fn(ctx):
      return multiple_values[ctx.replica_id_in_sync_group]
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    with self.assertRaisesRegex(ValueError, "not inside a replica context"):
      math_ops.cast(distributed_values, dtypes.float32)


class PerWorkerResourceTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(dataset_fn_as_tf_function=[True, False]))
  def testMapFnTracing(self, dataset_fn_as_tf_function):
    # For a PerWorkerResource to behave correctly when used in dataset.map,
    # the map_fn must not be traced only once, so that
    # PerWorkerResource.local_table can return the correct resource. This test
    # can detect the potential breakage of this behavior on TAP.
    self._traced_once = 0

    def map_fn(x):
      self._traced_once += 1
      return x

    def dataset_fn():
      dataset = dataset_ops.DatasetV2.from_tensors([0, 1, 2]).repeat().batch(
          2, drop_remainder=True)
      dataset = dataset.map(map_fn)
      return dataset

    datasets = []
    number_of_input_pipelines = 5

    if dataset_fn_as_tf_function:
      dataset_fn = def_function.function(dataset_fn)
      expected_tracing_times = 1
    else:
      expected_tracing_times = number_of_input_pipelines

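    # Build the input pipelines. When dataset_fn is wrapped in a tf.function,
    # map_fn should be traced once in total; otherwise it is traced once per
    # pipeline.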
    for _ in range(number_of_input_pipelines):
      datasets.append(dataset_fn())

    self.assertEqual(self._traced_once, expected_tracing_times)


class DistributedDelegateTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testGetAttr(self):
    class Foo(object):

      def __init__(self, x):
        self.x = x

    v = values_lib.DistributedDelegate((Foo(7), Foo(8)))
    self.assertEqual(7, v.x)
    with self.assertRaises(AttributeError):
      _ = v.y

  @test_util.run_in_graph_and_eager_modes
  def testOperatorOverride(self):
    v = values_lib.DistributedDelegate((7, 8))
    # v should act like int(7).
    self.assertEqual(8, v + 1)
    self.assertEqual(10, 3 + v)
    self.assertEqual(14, v + v)
    self.assertEqual(5, v - 2)
    self.assertEqual(6, 13 - v)
    self.assertEqual(0, v - v)
    self.assertEqual(14, v * 2)
    self.assertEqual(21, 3 * v)
    self.assertEqual(49, v * v)
    self.assertEqual(3.5, v / 2)
    self.assertEqual(1.5, 10.5 / v)
    self.assertEqual(3, v // 2)
    self.assertEqual(2, 15 // v)
    self.assertEqual(1, v % 2)
    self.assertEqual(2, 16 % v)
    # pylint: disable=g-generic-assert
    self.assertTrue(v < 12)
    self.assertTrue(v <= 12)
    self.assertFalse(v > 12)
    self.assertFalse(v >= 12)
    self.assertFalse(12 < v)
    self.assertFalse(12 <= v)
    self.assertTrue(12 > v)
    self.assertTrue(12 >= v)
    # pylint: enable=g-generic-assert
    self.assertEqual(3, v & 3)
    self.assertEqual(3, 11 & v)
    self.assertEqual(15, v | 8)
    self.assertEqual(23, 16 | v)
    self.assertEqual(4, v ^ 3)
    self.assertEqual(12, 11 ^ v)
    self.assertEqual(343, pow(v, 3))
    self.assertEqual(3, pow(v, 3, 10))
    self.assertEqual(128, pow(2, v))
    self.assertEqual(-7, -v)
    self.assertEqual(~7, ~v)
    self.assertEqual(7, abs(v))
    with self.assertRaises(TypeError):
      _ = v[2]

  @test_util.run_in_graph_and_eager_modes
  def testCopy(self):

    class Foo(object):

      def __init__(self, x):
        self.x = x

    v = values_lib.DistributedDelegate((Foo(7), Foo(8)))
    v_shallow_copy = copy.copy(v)
    self.assertEqual(v.x, v_shallow_copy.x)
    v_deep_copy = copy.deepcopy(v)
    self.assertEqual(v.x, v_deep_copy.x)


_TPU_STRATEGIES = (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)


def _make_replica_local(method, strategy=None):
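  """Builds a SyncOnReadVariable from one variable per device.

  Creates a resource variable on each device (initialized to 1. and 2.) and
  wraps them in a SyncOnReadVariable (or TPUSyncOnReadVariable for TPU
  strategies) with aggregation `method`. Returns the component variables and
  the wrapper.
  """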
  if strategy is None:
    devices = ("/device:GPU:0", "/device:CPU:0")
  else:
    devices = strategy.extended.worker_devices

  v = []
  for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
    with ops.device(d):
      v.append(variable_scope.get_variable(
          name=n, initializer=init, use_resource=True))

  if (strategy is not None) and isinstance(strategy, _TPU_STRATEGIES):
    var_cls = tpu_values.TPUSyncOnReadVariable
  else:
    var_cls = values_lib.SyncOnReadVariable
  replica_local = var_cls(strategy, v, method)
  return v, replica_local


class DistributedVariableTest(test.TestCase, parameterized.TestCase):

  def _assign_replica_local(self, v, new):
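    """Assigns new[i] to component v[i] on that variable's device."""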
    for var, n in zip(v, new):
      with ops.device(var.device):
        self.evaluate(var.assign(n))

  def _save_return_saver(self, sess, var):
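    """Saves `var` to the test temp dir; returns (save_path, saver)."""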
    saver = saver_lib.Saver(var_list=[var])
    test_dir = self.get_temp_dir()
    prefix = os.path.join(test_dir, "ckpt")
    return saver.save(sess, prefix), saver

  def _save(self, sess, var):
    save_path, _ = self._save_return_saver(sess, var)
    return save_path

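  # Soft placement lets ops fall back to an available device when the
  # requested one (e.g. a GPU) is missing on the test machine.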
  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  @test_util.run_in_graph_and_eager_modes(config=config)
  def testProperties(self):
    if context.num_gpus() < 1 and context.executing_eagerly():
      self.skipTest("A GPU is not available for this test in eager mode.")
    v, replica_local = _make_replica_local(
        variable_scope.VariableAggregation.SUM)

    self.assertEqual(v[0].constraint, replica_local.constraint)
    self.assertEqual(v[0].name, replica_local.name)
    self.assertEqual(v[0].dtype, replica_local.dtype)
    self.assertEqual(v[0].shape, replica_local.shape)
    self.assertEqual(variable_scope.VariableAggregation.SUM,
                     replica_local.aggregation)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu
          ],
          mode=["eager"]))
  def testCanPassToDefFun(self, distribution):

    @def_function.function
    def add1(x):
      return x + 1.

    with distribution.scope():
      v = variables_lib.Variable(
          1.,
          aggregation=variables_lib.VariableAggregation.MEAN,
          synchronization=variables_lib.VariableSynchronization.ON_READ)

    self.assertEqual(2., self.evaluate(add1(v)))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testTensorConversion(self, distribution):
    with context.graph_mode():
      _, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)
      converted = ops.convert_to_tensor(replica_local, as_ref=False)
      self.assertIsInstance(converted, ops.Tensor)
      self.assertEqual(converted.dtype, replica_local.dtype)

      converted = ops.convert_to_tensor(replica_local, as_ref=True)
      # Resource variables are converted to tensors as well when as_ref is True.
      self.assertIsInstance(converted, ops.Tensor)
      self.assertEqual(converted.dtype, replica_local.dtype)

  @combinations.generate(combinations.combine(
      distribution=[
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call,
          strategy_combinations.tpu_strategy,
          strategy_combinations.tpu_strategy_packed_var,
      ], mode=["eager"]))
  def testValueInCrossReplicaContext(self, distribution):
    value_list, replica_local = _make_replica_local(
        variable_scope.VariableAggregation.ONLY_FIRST_REPLICA, distribution)

    self.assertIsInstance(replica_local.value(), ops.Tensor)
    self.assertEqual(self.evaluate(replica_local.value()),
                     self.evaluate(value_list[0].value()))

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy_packed_var,
          ],
          mode=["eager"]))
  def testValueInDefaultReplicaContext(self, distribution):
    with distribution.scope():
      v1 = variables_lib.Variable(
          0.0,
          aggregation=variables_lib.VariableAggregation.SUM,
          synchronization=variables_lib.VariableSynchronization.ON_READ)
      v2 = variables_lib.Variable(
          0.0,
          aggregation=variables_lib.VariableAggregation.SUM,
          synchronization=variables_lib.VariableSynchronization.ON_READ)

    @def_function.function
    def replica_fn():
      v1.assign_add(1.0)
      v2.assign_add(2.0)

    distribution.run(replica_fn)
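    # With the two-replica strategies used here, the SUM-aggregated ON_READ
    # reads are 2.0 (v1) and 4.0 (v2), hence the expected total of 6.0.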
    sum_v = v1 + v2
    self.assertEqual(sum_v, 6.0)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.tpu_strategy_packed_var,
          ],
          mode=["eager"]))
  def testValueInFunctionCrossReplicaContext(self, distribution):
    with distribution.scope():
      v1 = variables_lib.Variable(
          0.0,
          aggregation=variables_lib.VariableAggregation.NONE,
          synchronization=variables_lib.VariableSynchronization.ON_WRITE)

    @def_function.function
    def assign_fn():
      v1.assign(1.0)

    assign_fn()
    self.assertEqual(v1, 1.0)

    # Make sure the function graph has the composite variable as input.
    graph_def = assign_fn.get_concrete_function().graph.as_graph_def()
    self.assertRegex(str(graph_def), "device:COMPOSITE:0")

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.tpu_strategy_packed_var,
          ],
          mode=["eager"]))
  def testReplicatedValueNameDeterministic(self, distribution):
    with distribution.scope():
      v1 = variables_lib.Variable(0.0, name="test_var_1")
      v2 = variables_lib.Variable(0.0, name="test_var_2")

    def fn():
      v1.assign_add(1.0)
      v2.assign_add(2.0)
      return v1 + v2

    @def_function.function
    def dist_run_fn():
      a = distribution.run(fn)
      return a

    concrete_fn = dist_run_fn.get_concrete_function()
    inputs = concrete_fn.graph.inputs
    self.assertLen(inputs, 2)
    # Before cl/433948982, the input name would include a non-deterministic
    # uid, e.g. "test_var_1_139726389910864/handle/inputs_0:0".
    self.assertEqual(inputs[0].name, "test_var_1/handle/inputs_0:0")
    self.assertEqual(inputs[1].name, "test_var_2/handle/inputs_0:0")

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution):
    with self.cached_session() as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [3., 4.])

      with distribution.scope():
        # Saves the current value of v[0] + v[1], 7.
        save_path, saver = self._save_return_saver(sess, replica_local)

        # Change the values between save and restore.
        self._assign_replica_local(v, [5., 6.])

        # Restores the saved value of 7. which gets divided equally
        # between the variables.
        saver.restore(sess, save_path)
        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveAndRestoreReplicaLocalMeanOneGraph(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      self.skipTest("A GPU is not available for this test in eager mode.")

    with self.cached_session() as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.MEAN, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [3., 4.])

      with distribution.scope():
        # Saves the current value of (v[0] + v[1])/2, 3.5.
        save_path, saver = self._save_return_saver(sess, replica_local)

        # Change the values between save and restore.
        self._assign_replica_local(v, [5., 6.])

        # Restores the saved value of 3.5 to both variables.
        saver.restore(sess, save_path)
        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))

  def _save_replica_local_mean(self, distribution):
    """Save variables with mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.MEAN, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [3., 4.])

      with distribution.scope():
        # Saves the current value of (v[0] + v[1])/2, 3.5
        save_path = self._save(sess, replica_local)

        # Change the values between save and restore.
        self._assign_replica_local(v, [5., 6.])
    return save_path

  def _save_replica_local_sum(self, distribution):
    """Save variables with mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [1.5, 2.])

      with distribution.scope():
        # Saves the current value of v[0] + v[1], 3.5
        save_path = self._save(sess, replica_local)

        # Change the values between save and restore.
        self._assign_replica_local(v, [5., 6.])
    return save_path

  def _save_normal(self):
    """Save variables without mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=1., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(3.5))

      # Saves the current value of var, 3.5.
      save_path = self._save(sess, var)

      # Change the values between save and restore.
      self.evaluate(var.assign(5.))
    return save_path

  def _restore_normal(self, save_path):
    """Restore to variables without mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=7., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(8.))

      # Restores the saved value of 3.5 to `var`.
      saver = saver_lib.Saver(var_list=[var])
      saver.restore(sess, save_path)
      self.assertEqual(3.5, self.evaluate(var))

  def _restore_replica_local_mean(self, save_path, distribution):
    """Restore to variables with mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.MEAN, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [7., 8.])

      with distribution.scope():
        # Restores the saved value of 3.5 to both variables.
        saver = saver_lib.Saver(var_list=[replica_local])
        saver.restore(sess, save_path)
        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))

  def _restore_replica_local_sum(self, save_path, distribution):
    """Restore to variables with mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [7., 8.])

      with distribution.scope():
        # Restores the saved value of 3.5, which gets divided equally
        # between the variables.
        saver = saver_lib.Saver(var_list=[replica_local])
        saver.restore(sess, save_path)
        self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]]))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalRestoreReplicaLocalMean(self, distribution):
    save_path = self._save_replica_local_mean(distribution)
    self._restore_replica_local_mean(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalRestoreReplicaLocalSum(self, distribution):
    save_path = self._save_replica_local_sum(distribution)
    self._restore_replica_local_sum(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalMeanRestoreNormal(self, distribution):
    save_path = self._save_replica_local_mean(distribution)
    self._restore_normal(save_path)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalSumRestoreNormal(self, distribution):
    save_path = self._save_replica_local_sum(distribution)
    self._restore_normal(save_path)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveNormalRestoreReplicaLocalMean(self, distribution):
    save_path = self._save_normal()
    self._restore_replica_local_mean(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveNormalRestoreReplicaLocalSum(self, distribution):
    save_path = self._save_normal()
    self._restore_replica_local_sum(save_path, distribution)


if __name__ == "__main__":
  ds_test_util.main()