xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/vars_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for the distributed values library."""
16
import itertools
import os
import uuid

from absl.testing import parameterized
from tensorflow.python.checkpoint import checkpoint as trackable_utils
from tensorflow.python.checkpoint import checkpoint_management as ckpt_manager
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.util import variable_utils
45
46
def strategy_and_run_tf_function_combinations():
  """Returns the strategy x `experimental_run_tf_function` test combinations.

  The combinations cover different strategies and whether a tf.function is
  passed into strategy.run.
  """
  # TODO(b/197981388): re-enable MWMS test
  # return combinations.combine(
  #     distribution=[
  #         strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
  #     ],
  #     mode=["graph", "eager"],
  #     experimental_run_tf_function=[True, False],
  #     use_var_policy=[True, False]) +
  tpu_strategies = [
      strategy_combinations.tpu_strategy,
      strategy_combinations.tpu_strategy_packed_var,
  ]
  return combinations.combine(
      distribution=tpu_strategies,
      mode=["graph", "eager"],
      experimental_run_tf_function=[True],
      use_var_policy=[True, False])
66
67
def strategy_with_var_policy():
  """Returns strategy combinations crossed with `use_var_policy` on and off."""
  strategies = [
      strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
      # TODO(b/197981388): re-enable MWMS test
      # strategy_combinations.multi_worker_mirrored_2x1_cpu,
      # strategy_combinations.multi_worker_mirrored_2x1_gpu,
      strategy_combinations.tpu_strategy,
      strategy_combinations.tpu_strategy_packed_var,
  ]
  return combinations.combine(
      distribution=strategies,
      mode=["graph", "eager"],
      use_var_policy=[True, False])
80
81
class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
  """Tests variables created under a strategy scope with each aggregation.

  The tests cover assignment, reads in replica and cross-replica context,
  operator overloads and checkpoint save/restore.
  """

  # Every aggregation mode, in the order the tests iterate over them.
  _ALL_AGGREGATIONS = (
      variables_lib.VariableAggregation.NONE,
      variables_lib.VariableAggregation.SUM,
      variables_lib.VariableAggregation.MEAN,
      variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
  )

  def _apply_update(self, distribution, experimental_run_tf_function, v, fn,
                    update_value, cross_replica):
    """Builds the result of calling the update method `fn` on `v`.

    Args:
      distribution: Strategy used to run and gather in replica context.
      experimental_run_tf_function: If true, the update is wrapped in a
        `def_function.function` before being handed to `distribution.run`.
      v: Variable to update.
      fn: Name of the update method to call on `v`, e.g. "assign_add".
      update_value: Argument passed to the update method.
      cross_replica: If true, the update is called directly; otherwise it is
        run through `distribution.run` and the results are gathered.

    Returns:
      The not-yet-evaluated result of the update.
    """
    update_fn = lambda: getattr(v, fn)(update_value)
    if cross_replica:
      return update_fn()
    if experimental_run_tf_function:
      update_fn = def_function.function(update_fn)
    return test_util.gather(distribution, distribution.run(update_fn))

  def _check_value_after_assign_add(self, distribution):
    """Adds 1. to a MEAN variable on each replica and checks the read value."""
    with distribution.scope():
      v = variables_lib.Variable(
          1., aggregation=variables_lib.VariableAggregation.MEAN)
      self.evaluate(variables_lib.global_variables_initializer())

      @def_function.function
      def f():
        # The read is sequenced after the assign_add through the control
        # dependency.
        with ops.control_dependencies([v.assign_add(1.)]):
          return v.value()

      results = self.evaluate(
          test_util.gather(distribution, distribution.run(f)))
      for value in results:
        self.assertEqual(2., value)

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssign(self, distribution, experimental_run_tf_function):
    # Every update below takes a zero-initialized variable to 1.
    updates = [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)]
    for (fn, update_value), aggregation, cross_replica in itertools.product(
        updates, self._ALL_AGGREGATIONS, [True, False]):
      # Assigning in replica context with SUM does not make sense, because you
      # can just use value * num_replicas. It also errors: 1. is not a
      # distributed value, which is unsupported for aggregation SUM.
      if (not cross_replica and
          aggregation == variables_lib.VariableAggregation.SUM):
        continue
      with distribution.scope():
        v = variable_scope.variable(0., aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      self.evaluate(
          self._apply_update(distribution, experimental_run_tf_function, v, fn,
                             update_value, cross_replica))
      for component in v._values:
        self.assertAllEqual(self.evaluate(component.read_value()),
                            self.evaluate(array_ops.ones_like(component)))

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssignOnWriteVar(self, distribution, experimental_run_tf_function):
    # The update values are themselves variables created under the strategy.
    with distribution.scope():
      v_to_assign = variable_scope.variable(
          2., aggregation=variables_lib.VariableAggregation.MEAN)
      v_to_assign_sub = variable_scope.variable(
          -2., aggregation=variables_lib.VariableAggregation.MEAN)

    updates = [("assign", v_to_assign), ("assign_add", v_to_assign),
               ("assign_sub", v_to_assign_sub)]
    for (fn, update_value), aggregation, cross_replica in itertools.product(
        updates, self._ALL_AGGREGATIONS, [True, False]):
      # SUM aggregation is skipped in both replica and cross-replica context.
      # Assigning in replica context with SUM does not make sense, because you
      # can just use value * num_replicas.
      # NOTE(review): that rationale only covers replica context; confirm that
      # skipping the cross-replica SUM case is intentional.
      if aggregation == variables_lib.VariableAggregation.SUM:
        continue
      with distribution.scope():
        v = variable_scope.variable(0., aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      self.evaluate(
          self._apply_update(distribution, experimental_run_tf_function, v, fn,
                             update_value, cross_replica))
      for component in v._values:
        self.assertAllEqual(2.0, self.evaluate(component.read_value()))

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssignPerReplicaVal(self, distribution, experimental_run_tf_function):

    if strategy_test_lib.is_tpu_strategy(distribution):
      self.skipTest("Assigning PerReplica values is not supported. See"
                    " sponge/80ba41f8-4220-4516-98ce-bbad48f9f11a.")

    with distribution.scope():
      per_replica_value = values.PerReplica(
          [constant_op.constant(2.0),
           constant_op.constant(2.0)])
      per_replica_sub_value = values.PerReplica(
          [constant_op.constant(-2.0),
           constant_op.constant(-2.0)])

    updates = [("assign", per_replica_value), ("assign_add", per_replica_value),
               ("assign_sub", per_replica_sub_value)]
    # We don't support assigning PerReplica values to vars in replica context
    # with aggregation=NONE.
    aggregations = [
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    # We don't support assigning PerReplica values to MirroredVariables in
    # cross replica context, so only replica context is exercised.
    cross_replica = False
    for (fn, update_value), aggregation in itertools.product(
        updates, aggregations):
      with distribution.scope():
        v = variable_scope.variable(0., aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      self.evaluate(
          self._apply_update(distribution, experimental_run_tf_function, v, fn,
                             update_value, cross_replica))
      # With SUM the two per-replica values of magnitude 2.0 add up to 4.0.
      if aggregation == variables_lib.VariableAggregation.SUM:
        expected = 4.0
      else:
        expected = 2.0
      for component in v._values:
        self.assertAllEqual(expected, self.evaluate(component.read_value()))

  @combinations.generate(strategy_with_var_policy())
  def testValueInReplicaContext(self, distribution):
    self._check_value_after_assign_add(distribution)

  @combinations.generate(strategy_with_var_policy())
  def testValueInReplicaContextAssignDirectValue(self, distribution,
                                                 use_var_policy):
    # `use_var_policy` is accepted from the combination but not read here.
    self._check_value_after_assign_add(distribution)

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testReadValueInReplicaContext(self, distribution,
                                    experimental_run_tf_function):
    for aggregation in self._ALL_AGGREGATIONS:
      with distribution.scope():
        v = variable_scope.variable(0., aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      read_var_fn = v.read_value
      if experimental_run_tf_function:
        read_var_fn = def_function.function(read_var_fn)
      results = self.evaluate(
          test_util.gather(distribution, distribution.run(read_var_fn)))
      # Each gathered result must match the corresponding component.
      for component, value in zip(v._values, results):
        self.assertAllEqual(self.evaluate(component.read_value()), value)

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testReadValueInCrossReplicaContext(self, distribution,
                                         experimental_run_tf_function):
    for aggregation in self._ALL_AGGREGATIONS:
      with distribution.scope():
        v = variable_scope.variable(2., aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())

      read_var_fn = v.read_value
      if experimental_run_tf_function:
        read_var_fn = def_function.function(read_var_fn)

      # Called directly, i.e. outside of `distribution.run`.
      results = read_var_fn()
      for component in v._values:
        self.assertEqual(self.evaluate(component.read_value()),
                         self.evaluate(results))

  @combinations.generate(strategy_with_var_policy())
  def testAssignOutOfScope(self, distribution):
    with distribution.scope():
      mirrored = variables_lib.Variable(1.)
    # The assignment happens after the strategy scope has been exited.
    self.evaluate(mirrored.assign(3.))
    self.assertEqual(self.evaluate(mirrored.read_value()), 3.)
    for component in mirrored.values:
      self.assertEqual(self.evaluate(component.read_value()), 3.)

  @combinations.generate(strategy_with_var_policy())
  def testInitializedToSameValueInsideEagerRun(self, distribution):
    if not context.executing_eagerly():
      self.skipTest("eager only test")
    if isinstance(distribution.extended,
                  collective_all_reduce_strategy.CollectiveAllReduceExtended):
      self.skipTest("Test for more than 1 device per worker only.")
    # One-element list so the nested function can store the variable.
    v = [None]

    @def_function.function
    def step():

      def f():
        if v[0] is None:
          v[0] = variables_lib.Variable(random_ops.random_normal([]))

      distribution.run(f)

    context.set_global_seed(None)
    step()
    vals = self.evaluate(v[0].values)
    self.assertAllEqual(vals[0], vals[1])

  @combinations.generate(strategy_with_var_policy())
  def testAggregationOnlyFirstReplica(self, distribution):
    if isinstance(distribution.extended,
                  collective_all_reduce_strategy.CollectiveAllReduceExtended):
      self.skipTest("b/212945803")
    with distribution.scope():
      v = variable_scope.variable(
          15.,
          synchronization=variables_lib.VariableSynchronization.ON_WRITE,
          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
    self.evaluate(variables_lib.global_variables_initializer())

    @def_function.function
    def assign():
      # Every replica assigns its own replica id.
      ctx = ds_context.get_replica_context()
      replica_id = ctx.replica_id_in_sync_group
      return v.assign(math_ops.cast(replica_id, dtypes.float32))

    per_replica_results = self.evaluate(
        test_util.gather(distribution, distribution.run(assign)))
    # The per-replica values should always match the first replica's value.
    self.assertAllEqual(
        array_ops.zeros(distribution.num_replicas_in_sync, dtypes.float32),
        per_replica_results)

  @combinations.generate(strategy_with_var_policy())
  def testInitScope(self, distribution):
    if not context.executing_eagerly():
      self.skipTest("eager only")

    class C(object):
      pass

    # Plain attribute holder for the variables created inside `assign`.
    obj = C()
    obj.w = None
    obj.v = None

    @def_function.function
    def assign():
      with ops.init_scope():
        if obj.w is None:
          obj.w = variables_lib.Variable(
              0., aggregation=variables_lib.VariableAggregation.MEAN)
          # `obj.v` takes its initial value from `obj.w`.
          obj.v = variables_lib.Variable(
              obj.w.read_value(),
              aggregation=variables_lib.VariableAggregation.MEAN)
          self.evaluate(variables_lib.global_variables_initializer())

      return obj.v.assign_add(2.)

    per_replica_results = self.evaluate(
        test_util.gather(distribution, distribution.run(assign)))
    self.assertAllEqual([2., 2.], per_replica_results)

  @combinations.generate(strategy_with_var_policy())
  def testOperatorOverride(self, distribution):

    if not context.executing_eagerly() and isinstance(
        distribution.extended,
        collective_all_reduce_strategy.CollectiveAllReduceExtended):
      self.skipTest("b/212954197")

    with distribution.scope():
      v = variable_scope.variable(
          1, aggregation=variables_lib.VariableAggregation.SUM)
      self.evaluate(variables_lib.global_variables_initializer())

    # `+` is checked first outside of `distribution.run`...
    self.assertEqual(2, self.evaluate(v + 1))

    @def_function.function
    def add():
      return v + 1

    # ...and then inside of it.
    per_replica_results = self.evaluate(
        test_util.gather(distribution, distribution.run(add)))
    self.assertAllEqual([2, 2], per_replica_results)

  @combinations.generate(
      combinations.combine(
          strategy=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ],
          mode=["eager"],
          use_var_policy=[True, False]))
  def testSaveAndRestoreOnWrite(self, strategy):
    aggregation = [
        variable_scope.VariableAggregation.NONE,
        variable_scope.VariableAggregation.ONLY_FIRST_REPLICA,
        variable_scope.VariableAggregation.SUM,
        variable_scope.VariableAggregation.MEAN
    ]
    for agg in aggregation:
      # The two "normal" variables are created outside the strategy scope.
      v_normal_restore = variables_lib.Variable(1.0)
      v_normal_save = variables_lib.Variable(3.0)
      with strategy.scope():
        v_on_write = variables_lib.Variable(2.0, aggregation=agg)

        # Save ONWRITE Restore ONWRITE
        # Save
        ckpt = trackable_utils.Checkpoint(var=v_on_write)
        # Checkpoints go under the test's temp dir instead of a hard-coded
        # /tmp path; the uuid keeps each directory unique.
        manager = ckpt_manager.CheckpointManager(
            ckpt,
            os.path.join(self.get_temp_dir(), "ckpt_" + str(uuid.uuid4())),
            max_to_keep=None)
        manager.save()
        # Restore
        ckpt.restore(manager.latest_checkpoint)
        self.assertEqual(2.0, self.evaluate(v_on_write._values[0]))
        self.assertEqual(2.0, self.evaluate(v_on_write.read_value()))

        # Save Mirrored Restore Normal
        # We've already saved Mirrored, so we only need to restore normal
        ckpt_normal = trackable_utils.Checkpoint(var=v_normal_restore)
        ckpt_normal.restore(manager.latest_checkpoint)
        self.assertEqual(2.0, self.evaluate(v_on_write._values[0]))
        self.assertEqual(2.0, self.evaluate(v_normal_restore.read_value()))

        # Save Normal Restore Mirrored
        # Save
        ckpt = trackable_utils.Checkpoint(var=v_normal_save)
        manager_2 = ckpt_manager.CheckpointManager(
            ckpt,
            os.path.join(self.get_temp_dir(), "ckpt_" + str(uuid.uuid4())),
            max_to_keep=None)
        manager_2.save()
        # Restore
        ckpt_on_write = trackable_utils.Checkpoint(var=v_on_write)
        ckpt_on_write.restore(manager_2.latest_checkpoint)
        self.assertEqual(3.0, self.evaluate(v_on_write._values[0]))
        self.assertEqual(3.0, self.evaluate(v_on_write.read_value()))
469
470
# Parameter grids for the scatter tests: each pairs a single strategy with
# both graph and eager mode.
ms_combination = combinations.combine(
    mode=["graph", "eager"],
    distribution=[strategy_combinations.mirrored_strategy_with_gpu_and_cpu])
tpu_combination = combinations.combine(
    mode=["graph", "eager"],
    distribution=[strategy_combinations.tpu_strategy_packed_var])
477
478
class OnWriteVariableSyncScatterTests(test.TestCase, parameterized.TestCase):
  """Exercises the scatter_* methods of variables created under a strategy.

  The expected values in these tests are written for two replicas.
  """

  def _run_and_gather(self, distribution, fn, args=()):
    """Runs `fn` via `distribution.run`, then gathers and evaluates the result."""
    per_replica = distribution.run(fn, args=args)
    return self.evaluate(test_util.gather(distribution, per_replica))

  @combinations.generate(ms_combination)
  def testScatterSub(self, distribution):
    with distribution.scope():
      var = variables_lib.Variable(
          [0., 0., 0.], aggregation=variables_lib.VariableAggregation.MEAN)
    self.evaluate(var.initializer)

    @def_function.function
    def scatter_sub():
      # Each replica subtracts (id, id + 1) at rows (id, id + 1).
      replica_id = ds_context.get_replica_context().replica_id_in_sync_group
      amounts = array_ops.stack([
          math_ops.cast(replica_id, dtypes.float32),
          math_ops.cast(replica_id + 1, dtypes.float32)
      ])
      rows = array_ops.stack([replica_id, replica_id + 1])
      delta = indexed_slices.IndexedSlices(
          values=amounts, indices=rows, dense_shape=(3,))
      return var.scatter_sub(delta)

    local_results = distribution.experimental_local_results(
        distribution.run(scatter_sub))
    per_replica_results = self.evaluate(local_results)
    self.assertAllEqual([[0., -1., -1.], [0., -1., -1.]], per_replica_results)

  @combinations.generate(ms_combination)
  def testScatterAdd(self, distribution):
    with distribution.scope():
      var = variables_lib.Variable(
          [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
    self.evaluate(var.initializer)

    @def_function.function
    def scatter_add():
      # Each replica adds (id, id + 1) at rows (id, id + 1).
      replica_id = ds_context.get_replica_context().replica_id_in_sync_group
      amounts = array_ops.stack([replica_id, replica_id + 1])
      rows = array_ops.stack([replica_id, replica_id + 1])
      delta = indexed_slices.IndexedSlices(
          values=amounts, indices=rows, dense_shape=(3,))
      return var.scatter_add(delta)

    per_replica_results = self._run_and_gather(distribution, scatter_add)
    self.assertAllEqual([[0, 2, 2], [0, 2, 2]], per_replica_results)

  @combinations.generate(ms_combination)
  def testScatterDiv(self, distribution):
    with distribution.scope():
      var = variables_lib.Variable(
          [1, 6, 1], aggregation=variables_lib.VariableAggregation.SUM)
    self.evaluate(var.initializer)

    @def_function.function
    def scatter_div():
      # Each replica divides its own row (row id == replica id) by id + 2.
      replica_id = ds_context.get_replica_context().replica_id_in_sync_group
      divisor = array_ops.reshape(replica_id + 2, [1])
      row = array_ops.reshape(replica_id, [1])
      delta = indexed_slices.IndexedSlices(
          values=divisor, indices=row, dense_shape=(3,))
      return var.scatter_div(delta)

    per_replica_results = self._run_and_gather(distribution, scatter_div)
    self.assertAllEqual([[0, 2, 1], [0, 2, 1]], per_replica_results)

  @combinations.generate(ms_combination)
  def testScatterMul(self, distribution):
    with distribution.scope():
      var = variables_lib.Variable(
          [2., 1., 1.], aggregation=variables_lib.VariableAggregation.MEAN)
    self.evaluate(var.initializer)

    @def_function.function
    def scatter_mul():
      # Each replica multiplies its own row (row id == replica id) by id + 2.
      replica_id = ds_context.get_replica_context().replica_id_in_sync_group
      factor = array_ops.reshape(
          math_ops.cast(replica_id + 2, dtypes.float32), [1])
      row = array_ops.reshape(replica_id, [1])
      delta = indexed_slices.IndexedSlices(
          values=factor, indices=row, dense_shape=(3,))
      return var.scatter_mul(delta)

    per_replica_results = self._run_and_gather(distribution, scatter_mul)
    self.assertAllClose([[2., 1.5, 1.], [2., 1.5, 1.]], per_replica_results)

  @combinations.generate(ms_combination)
  def testScatterMin(self, distribution):
    with distribution.scope():
      sum_var = variables_lib.Variable(
          [0, 2, 0], aggregation=variables_lib.VariableAggregation.SUM)
      first_replica_var = variables_lib.Variable(
          [0, 2, 0],
          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
    self.evaluate(variables_lib.global_variables_initializer())

    @def_function.function
    def scatter_min(v):
      delta = indexed_slices.IndexedSlices(
          values=array_ops.identity([1]),
          indices=array_ops.identity([1]),
          dense_shape=(3,))
      return v.scatter_min(delta)

    # With SUM aggregation the op is expected to raise.
    with self.assertRaisesRegex(NotImplementedError, "scatter_min.*"):
      self._run_and_gather(distribution, scatter_min, args=(sum_var,))

    per_replica_results = self._run_and_gather(
        distribution, scatter_min, args=(first_replica_var,))
    self.assertAllClose([[0, 1, 0], [0, 1, 0]], per_replica_results)

  @combinations.generate(ms_combination)
  def testScatterMax(self, distribution):
    with distribution.scope():
      sum_var = variables_lib.Variable(
          [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
      first_replica_var = variables_lib.Variable(
          [0, 0, 0],
          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
    self.evaluate(variables_lib.global_variables_initializer())

    @def_function.function
    def scatter_max(v):
      delta = indexed_slices.IndexedSlices(
          values=array_ops.identity([1]),
          indices=array_ops.identity([0]),
          dense_shape=(3,))
      return v.scatter_max(delta)

    # With SUM aggregation the op is expected to raise.
    with self.assertRaisesRegex(NotImplementedError, "scatter_max.*"):
      self._run_and_gather(distribution, scatter_max, args=(sum_var,))

    per_replica_results = self._run_and_gather(
        distribution, scatter_max, args=(first_replica_var,))
    self.assertAllClose([[1, 0, 0], [1, 0, 0]], per_replica_results)

  @combinations.generate(ms_combination)
  def testScatterUpdate(self, distribution):
    with distribution.scope():
      sum_var = variables_lib.Variable(
          [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
      first_replica_var = variables_lib.Variable(
          [0, 0, 0],
          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
    self.evaluate(variables_lib.global_variables_initializer())

    @def_function.function
    def scatter_update(v):
      delta = indexed_slices.IndexedSlices(
          values=array_ops.identity([3]),
          indices=array_ops.identity([1]),
          dense_shape=(3,))
      return v.scatter_update(delta)

    # With SUM aggregation the op is expected to raise.
    with self.assertRaisesRegex(NotImplementedError, "scatter_update.*"):
      self._run_and_gather(distribution, scatter_update, args=(sum_var,))

    per_replica_results = self._run_and_gather(
        distribution, scatter_update, args=(first_replica_var,))
    self.assertAllClose([[0, 3, 0], [0, 3, 0]], per_replica_results)

  @combinations.generate(ms_combination + tpu_combination)
  def testScatterOpsWithNoneAggregation(self, distribution):

    def assert_close(v, op, delta, expect):
      # Applies scatter method `op` of `v` on every replica and compares each
      # replica's result with `expect`.
      scatter_op = getattr(v, op)

      @def_function.function
      def scatter_xxx():
        return scatter_op(delta)

      local_results = distribution.experimental_local_results(
          distribution.run(scatter_xxx))
      per_replica_results = self.evaluate(
          variable_utils.convert_variables_to_tensors(local_results))
      self.assertAllClose([expect, expect], per_replica_results)

    with distribution.scope():
      var = variables_lib.Variable(
          [4.], aggregation=variables_lib.VariableAggregation.NONE)
    self.evaluate(variables_lib.global_variables_initializer())

    delta = indexed_slices.IndexedSlices(
        values=array_ops.identity([2.]),
        indices=array_ops.identity([0]),
        dense_shape=(1,))

    # The ops run in sequence on the same variable, so each expected value
    # builds on the result of the previous op.
    steps = (
        ("scatter_sub", [2.]),
        ("scatter_add", [4.]),
        ("scatter_max", [4.]),
        ("scatter_min", [2.]),
        ("scatter_mul", [4.]),
        ("scatter_div", [2.]),
        ("scatter_update", [2.]),
    )
    for op_name, expected in steps:
      assert_close(var, op_name, delta, expected)

  @combinations.generate(ms_combination + tpu_combination)
  def testScatterOpsInCrossReplicaContext(self, distribution):
    with distribution.scope():
      sum_var = variables_lib.Variable(
          [1, 1, 1], aggregation=variables_lib.VariableAggregation.SUM)
      default_var = variables_lib.Variable([1, 1, 1])
    self.evaluate(variables_lib.global_variables_initializer())

    delta = indexed_slices.IndexedSlices(
        values=array_ops.identity([2]),
        indices=array_ops.identity([0]),
        dense_shape=(3,))
    # The scatter ops are called directly rather than via `distribution.run`.
    with distribution.scope():
      self.evaluate(sum_var.scatter_add(delta))
      self.assertAllEqual([3, 1, 1], self.evaluate(sum_var.read_value()))

      self.evaluate(default_var.scatter_min(delta))
      self.assertAllEqual([1, 1, 1], self.evaluate(default_var.read_value()))
706
707
708class OnReadVariableSyncTest(test.TestCase, parameterized.TestCase):
709
710  @combinations.generate(strategy_and_run_tf_function_combinations())
711  def testAssign(self, distribution, experimental_run_tf_function):
712
713    def assign(fn, v, update_value, cross_replica):
714      update_fn = lambda: getattr(v, fn)(update_value)
715      if cross_replica:
716        return update_fn()
717      else:
718        if experimental_run_tf_function:
719          update_fn = def_function.function(update_fn)
720        return test_util.gather(distribution, distribution.run(update_fn))
721
722    updates = [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)]
723    aggregations = [
724        variables_lib.VariableAggregation.NONE,
725        variables_lib.VariableAggregation.SUM,
726        variables_lib.VariableAggregation.MEAN,
727        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
728    ]
729    options = list(
730        x for x in itertools.product(updates, aggregations, [True, False]))
731    for update, aggregation, cross_replica in options:
732      # VariableAggregation.SUM in cross-replica mode is tested below,
733      # VariableAggregation.NONE in cross-replica mode is not supported.
734      if cross_replica and aggregation in [
735          variables_lib.VariableAggregation.SUM,
736          variables_lib.VariableAggregation.NONE,
737      ]:
738        continue
739      with distribution.scope():
740        v = variable_scope.variable(
741            0.,
742            synchronization=variables_lib.VariableSynchronization.ON_READ,
743            aggregation=aggregation)
744      self.evaluate(variables_lib.global_variables_initializer())
745      fn, update_value = update
746      self.evaluate(assign(fn, v, update_value, cross_replica))
747      for component in v._values:
748        self.assertAllEqual(self.evaluate(component.read_value()),
749                            self.evaluate(array_ops.ones_like(component)))
750
751  @combinations.generate(strategy_and_run_tf_function_combinations())
752  def testAssignOnReadVar(self, distribution, experimental_run_tf_function):
753
754    with distribution.scope():
755      v_to_assign = variable_scope.variable(
756          2., aggregation=variables_lib.VariableAggregation.MEAN)
757      v_to_assign_sub = variable_scope.variable(
758          -2., aggregation=variables_lib.VariableAggregation.MEAN)
759
760    def assign(fn, v, update_value, cross_replica):
761      update_fn = lambda: getattr(v, fn)(update_value)
762      if cross_replica:
763        return update_fn()
764      else:
765        if experimental_run_tf_function:
766          update_fn = def_function.function(update_fn)
767        return test_util.gather(distribution, distribution.run(update_fn))
768
769    updates = [("assign", v_to_assign), ("assign_add", v_to_assign),
770               ("assign_sub", v_to_assign_sub)]
771    expected_cross_replica = {
772        variables_lib.VariableAggregation.SUM: 1.0,
773        variables_lib.VariableAggregation.MEAN: 2.0,
774        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA: 2.0
775    }
776    expected_replica = {
777        variables_lib.VariableAggregation.SUM: 2.0,
778        variables_lib.VariableAggregation.MEAN: 2.0,
779        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA: 2.0
780    }
781    # aggregation=NONE is not supported for OnReadVariables.
782    aggregations = [
783        variables_lib.VariableAggregation.SUM,
784        variables_lib.VariableAggregation.MEAN,
785        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
786    ]
787    options = list(
788        x for x in itertools.product(updates, aggregations, [True, False]))
789    for update, aggregation, cross_replica in options:
790      # assign in replica context with SUM does not make sense cause you can
791      # just do value * num replicas error is 1. is not a distributed value and
792      # is unsupported for aggregation SUM
793      if aggregation == variables_lib.VariableAggregation.SUM:
794        continue
795      with distribution.scope():
796        v = variable_scope.variable(
797            0.,
798            aggregation=aggregation)
799      self.evaluate(variables_lib.global_variables_initializer())
800      fn, update_value = update
801      self.evaluate(assign(fn, v, update_value, cross_replica))
802      if cross_replica:
803        for component in v._values:
804          self.assertAllEqual(expected_cross_replica.get(aggregation),
805                              self.evaluate(component.read_value()))
806      else:
807        for component in v._values:
808          self.assertAllEqual(expected_replica.get(aggregation),
809                              self.evaluate(component.read_value()))
810
811  @combinations.generate(strategy_and_run_tf_function_combinations())
812  def testAssignPerReplicaVal(self, distribution, experimental_run_tf_function):
813
814    if strategy_test_lib.is_tpu_strategy(distribution):
815      self.skipTest("Assigning PerReplica values is not supported. See"
816                    " sponge/80ba41f8-4220-4516-98ce-bbad48f9f11a.")
817
818    self.skipTest("We don't support assiging PerReplica values in cross "
819                  "replica context or replica context. see error in "
820                  "sponge/2b2e54c1-eda6-4534-82e1-c73b1dcd517f.")
821
822    with distribution.scope():
823      per_replica_value = values.PerReplica(
824          [constant_op.constant(2.0),
825           constant_op.constant(2.0)])
826
827    def assign(fn, v, update_value, cross_replica):
828      update_fn = lambda: getattr(v, fn)(update_value)
829      if cross_replica:
830        return update_fn()
831      else:
832        if experimental_run_tf_function:
833          update_fn = def_function.function(update_fn)
834        return test_util.gather(distribution, distribution.run(update_fn))
835
836    updates = [("assign", per_replica_value)]
837    # We don't support assigning PerReplica valus to vars in replica context
838    # with aggregation=NONE.
839    aggregations = [
840        variables_lib.VariableAggregation.SUM,
841        variables_lib.VariableAggregation.MEAN,
842        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
843    ]
844    options = list(
845        x for x in itertools.product(updates, aggregations, [True, False]))
846    for update, aggregation, cross_replica in options:
847      # assign in replica context with SUM does not make sense cause you can
848      # just do value * num replicas error is 1. is not a distributed value and
849      # is unsupported for aggregation SUM
850      with distribution.scope():
851        v = variable_scope.variable(
852            0.,
853            synchronization=variables_lib.VariableSynchronization.ON_READ,
854            aggregation=aggregation)
855      self.evaluate(variables_lib.global_variables_initializer())
856      fn, update_value = update
857      # with self.assertRaisesRegex(ValueError, "Attempt to convert a value "):
858      self.evaluate(assign(fn, v, update_value, cross_replica))
859      if aggregation == variables_lib.VariableAggregation.SUM:
860        expected = 4.0
861      else:
862        expected = 2.0
863      for component in v._values:
864        self.assertAllEqual(expected, self.evaluate(component.read_value()))
865
866  @combinations.generate(strategy_and_run_tf_function_combinations())
867  def testAssignDtypeConversion(self, distribution,
868                                experimental_run_tf_function):
869
870    def assign(fn, v, update_value, cross_replica):
871      update_fn = lambda: getattr(v, fn)(update_value)
872      if cross_replica:
873        return update_fn()
874      else:
875        if experimental_run_tf_function:
876          update_fn = def_function.function(update_fn)
877        return test_util.gather(distribution, distribution.run(update_fn))
878
879    updates = [("assign", 1), ("assign_add", 1), ("assign_sub", -1)]
880    aggregations = [
881        variables_lib.VariableAggregation.NONE,
882        variables_lib.VariableAggregation.SUM,
883        variables_lib.VariableAggregation.MEAN,
884        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
885    ]
886    options = list(
887        x for x in itertools.product(updates, aggregations, [True, False]))
888    for update, aggregation, cross_replica in options:
889      # VariableAggregation.SUM in cross-replica mode is tested below,
890      # VariableAggregation.NONE in cross-replica mode is not supported.
891      if cross_replica and aggregation in [
892          variables_lib.VariableAggregation.SUM,
893          variables_lib.VariableAggregation.NONE,
894      ]:
895        continue
896      with distribution.scope():
897        v = variable_scope.variable(
898            0.,
899            synchronization=variables_lib.VariableSynchronization.ON_READ,
900            aggregation=aggregation)
901      self.evaluate(variables_lib.global_variables_initializer())
902      fn, update_value = update
903      self.evaluate(assign(fn, v, update_value, cross_replica))
904      for component in v._values:
905        self.assertAllEqual(self.evaluate(component.read_value()),
906                            self.evaluate(array_ops.ones_like(component)))
907
908  @combinations.generate(strategy_with_var_policy())
909  def testAssignWithAggregationSum(self, distribution):
910    with distribution.scope():
911      v = variable_scope.variable(
912          0.,
913          synchronization=variables_lib.VariableSynchronization.ON_READ,
914          aggregation=variables_lib.VariableAggregation.SUM)
915    self.evaluate(variables_lib.global_variables_initializer())
916    self.evaluate(v.assign(1. * distribution.num_replicas_in_sync))
917    for component in v._values:
918      self.assertAllEqual(self.evaluate(component.read_value()),
919                          self.evaluate(array_ops.ones_like(component)))
920
921  @combinations.generate(strategy_with_var_policy())
922  def testAssignAddSubWithAggregationSum(self, distribution):
923    with distribution.scope():
924      v = variable_scope.variable(
925          0.,
926          synchronization=variables_lib.VariableSynchronization.ON_READ,
927          aggregation=variables_lib.VariableAggregation.SUM)
928    self.evaluate(variables_lib.global_variables_initializer())
929    with self.assertRaisesRegex(
930        ValueError, "SyncOnReadVariable does not support "):
931      self.evaluate(v.assign_add(1.))
932    with self.assertRaisesRegex(
933        ValueError, "SyncOnReadVariable does not support "):
934      self.evaluate(v.assign_sub(1.))
935
936  @combinations.generate(strategy_and_run_tf_function_combinations())
937  def testReadValueInReplicaContext(self, distribution,
938                                    experimental_run_tf_function):
939    aggregations = [
940        variables_lib.VariableAggregation.NONE,
941        variables_lib.VariableAggregation.SUM,
942        variables_lib.VariableAggregation.MEAN,
943        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
944    ]
945    for aggregation in aggregations:
946      with distribution.scope():
947        v = variable_scope.variable(
948            0.,
949            synchronization=variables_lib.VariableSynchronization.ON_READ,
950            aggregation=aggregation)
951      self.evaluate(variables_lib.global_variables_initializer())
952      if experimental_run_tf_function:
953        read_var_fn = def_function.function(v.read_value)
954      else:
955        read_var_fn = v.read_value
956      results = self.evaluate(
957          test_util.gather(distribution, distribution.run(read_var_fn)))
958      for component, value in zip(v._values, results):
959        self.assertAllEqual(self.evaluate(component.read_value()), value)
960
961  @combinations.generate(strategy_and_run_tf_function_combinations())
962  def testReadValueInCrossReplicaContext(self, distribution,
963                                         experimental_run_tf_function):
964    aggregations = [
965        variables_lib.VariableAggregation.SUM,
966        variables_lib.VariableAggregation.MEAN,
967        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
968    ]
969    for aggregation in aggregations:
970      if strategy_test_lib.is_tpu_strategy(distribution):
971        resolver = tpu_cluster_resolver.TPUClusterResolver("")
972        tpu_strategy_util.initialize_tpu_system(resolver)
973      with distribution.scope():
974        v = variable_scope.variable(
975            0.,
976            synchronization=variables_lib.VariableSynchronization.ON_READ,
977            aggregation=aggregation)
978      self.evaluate(variables_lib.global_variables_initializer())
979
980      def assign(v=v):
981        ctx = ds_context.get_replica_context()
982        replica_id = ctx.replica_id_in_sync_group
983        return v.assign(math_ops.cast(replica_id, dtypes.float32))
984
985      if experimental_run_tf_function:
986        assign = def_function.function(assign)
987
988      self.evaluate(test_util.gather(distribution, distribution.run(assign)))
989      num_replicas = distribution.num_replicas_in_sync
990      sum_of_replica_values = num_replicas * (num_replicas - 1) / 2.
991      if aggregation == variables_lib.VariableAggregation.SUM:
992        expected = sum_of_replica_values
993      elif aggregation == variables_lib.VariableAggregation.MEAN:
994        expected = sum_of_replica_values / num_replicas
995      else:
996        expected = 0
997      self.assertEqual(expected, self.evaluate(v.read_value()), aggregation)
998      self.assertEqual(expected, self.evaluate(v.value()), aggregation)
999      self.assertEqual(expected, self.evaluate(v), aggregation)
1000      self.assertEqual(expected, self.evaluate(array_ops.identity(v)),
1001                       aggregation)
1002
1003  @combinations.generate(strategy_and_run_tf_function_combinations())
1004  def testAllReduce(self, distribution, experimental_run_tf_function):
1005    with distribution.scope():
1006      v = variable_scope.variable(
1007          2.,
1008          synchronization=variables_lib.VariableSynchronization.ON_WRITE,
1009          aggregation=variables_lib.VariableAggregation.MEAN)
1010    self.evaluate(variables_lib.global_variables_initializer())
1011
1012    def all_reduce():
1013      ctx = ds_context.get_replica_context()
1014      replica_id = ctx.replica_id_in_sync_group
1015      return ctx.all_reduce("SUM", v) + math_ops.cast(replica_id,
1016                                                      dtypes.float32)
1017
1018    if experimental_run_tf_function:
1019      all_reduce = def_function.function(all_reduce)
1020
1021    per_replica_results = self.evaluate(
1022        test_util.gather(distribution, distribution.run(all_reduce)))
1023    expected_result = []
1024    for i in range(distribution.num_replicas_in_sync):
1025      expected_result.append(2.0 * distribution.num_replicas_in_sync +
1026                             1.0 * i)
1027    self.assertAllEqual(per_replica_results, tuple(expected_result))
1028
1029  @combinations.generate(strategy_and_run_tf_function_combinations())
1030  def testAssignPerReplicaBeforeRead(self, distribution,
1031                                     experimental_run_tf_function):
1032    aggregations = [
1033        variables_lib.VariableAggregation.SUM,
1034        variables_lib.VariableAggregation.MEAN,
1035        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
1036    ]
1037    for aggregation in aggregations:
1038      with distribution.scope():
1039        v = variable_scope.variable(
1040            0.,
1041            synchronization=variables_lib.VariableSynchronization.ON_READ,
1042            aggregation=aggregation)
1043      self.evaluate(variables_lib.global_variables_initializer())
1044
1045      def assign(var=v):
1046        ctx = ds_context.get_replica_context()
1047        replica_id = ctx.replica_id_in_sync_group
1048        return var.assign(math_ops.cast(replica_id, dtypes.float32))
1049
1050      if experimental_run_tf_function:
1051        assign = def_function.function(assign)
1052
1053      per_replica_results = self.evaluate(
1054          test_util.gather(distribution, distribution.run(assign)))
1055      expected_result = []
1056      for i in range(distribution.num_replicas_in_sync):
1057        expected_result.append(1.0 * i)
1058      self.assertAllEqual(per_replica_results, tuple(expected_result))
1059
1060  @combinations.generate(strategy_with_var_policy())
1061  def testReadValueWithAggregationNoneInCrossReplicaContext(self, distribution):
1062    with distribution.scope():
1063      v = variable_scope.variable(
1064          0.,
1065          synchronization=variables_lib.VariableSynchronization.ON_READ,
1066          aggregation=variables_lib.VariableAggregation.NONE)
1067    self.evaluate(variables_lib.global_variables_initializer())
1068    with self.assertRaisesRegex(
1069        ValueError, "Could not convert from .* VariableAggregation\\.NONE"):
1070      self.evaluate(v.read_value())
1071
1072  @combinations.generate(strategy_with_var_policy())
1073  def testInitializedToSameValueInsideEagerRun(self, distribution):
1074    if not context.executing_eagerly(): self.skipTest("eager only")
1075    if isinstance(distribution.extended,
1076                  collective_all_reduce_strategy.CollectiveAllReduceExtended):
1077      self.skipTest("Test for more than 1 device per worker only.")
1078
1079    v = [None]
1080    @def_function.function
1081    def step():
1082      def f():
1083        if v[0] is None:
1084          v[0] = variables_lib.Variable(
1085              random_ops.random_normal([]),
1086              synchronization=variables_lib.VariableSynchronization.ON_READ)
1087
1088      distribution.run(f)
1089
1090    context.set_global_seed(None)
1091    step()
1092    vals = self.evaluate(v[0].values)
1093    self.assertAllEqual(vals[0], vals[1])
1094
1095  @combinations.generate(strategy_with_var_policy())
1096  def testOperatorOverride(self, distribution):
1097
1098    with distribution.scope():
1099      v = variable_scope.variable(
1100          0.0,
1101          synchronization=variables_lib.VariableSynchronization.ON_READ,
1102          aggregation=variables_lib.VariableAggregation.MEAN)
1103      self.evaluate(variables_lib.global_variables_initializer())
1104
1105      @def_function.function
1106      def assign():
1107        ctx = ds_context.get_replica_context()
1108        replica_id = ctx.replica_id_in_sync_group
1109        return v.assign(math_ops.cast(replica_id, dtypes.float32))
1110
1111      # Assign different replicas with different values.
1112      self.evaluate(test_util.gather(distribution, distribution.run(assign)))
1113      self.assertEqual(1.5, self.evaluate(v + 1))
1114
1115      @def_function.function
1116      def add():
1117        return v + 1
1118
1119      per_replica_results = self.evaluate(
1120          test_util.gather(distribution, distribution.run(add)))
1121      self.assertAllEqual([1, 2], per_replica_results)
1122
  @combinations.generate(
      combinations.combine(
          strategy=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ],
          mode=["eager"],
          use_var_policy=[True, False]))
  def testSaveAndRestoreOnRead(self, strategy):
    """Checks checkpoint round trips between ON_READ and plain variables.

    For SUM and MEAN aggregation, three scenarios are exercised in order:
      1. save an ON_READ variable and restore it into the same variable;
      2. restore that checkpoint into a plain variable;
      3. save a plain variable and restore it into the ON_READ variable.

    The expected values are the ones spelled out in the inline comments and
    assertions below.
    """
    aggregation = [variable_scope.VariableAggregation.SUM,
                   variable_scope.VariableAggregation.MEAN]
    for agg in aggregation:
      # Plain variables, created outside the strategy scope.
      v_normal_restore = variables_lib.Variable(1.0)
      v_normal_save = variables_lib.Variable(2.0)

      with strategy.scope():
        v_on_read = variables_lib.Variable(
            1.0, synchronization=variable_scope.VariableSynchronization.ON_READ,
            aggregation=agg)

        @def_function.function
        def assign_fn():
          # Assigns 3. when the cluster resolver reports task type "worker"
          # or the replica id equals 1, and 4. otherwise.
          cluster_resolver = strategy.cluster_resolver
          replica_ctx = ds_context.get_replica_context()
          if ((cluster_resolver and cluster_resolver.task_type == "worker") or
              math_ops.equal(replica_ctx.replica_id_in_sync_group,
                             constant_op.constant(1))):
            v_on_read.assign(3.)  # pylint:disable=cell-var-from-loop
          else:
            v_on_read.assign(4.)  # pylint:disable=cell-var-from-loop

        strategy.run(assign_fn)

        # Save ONREAD, restore ONREAD
        # Saves v[0] + v[1] = 7 for SUM and 3.5 for MEAN.
        # NOTE(review): the checkpoint directory is a hard-coded "/tmp" path
        # made unique with a uuid; consider `self.get_temp_dir()` so the
        # files live under the test's temp directory.
        ckpt = trackable_utils.Checkpoint(var=v_on_read)
        manager = ckpt_manager.CheckpointManager(
            ckpt, "/tmp/ckpt_" + str(uuid.uuid4()), max_to_keep=None)
        manager.save()
        # Restores a value of 7/2 = 3.5 for SUM and 3.5 for MEAN.
        ckpt.restore(manager.latest_checkpoint)
        self.assertEqual(3.5, self.evaluate(v_on_read._values[0]))

        # Save ONREAD, restore normal
        # The plain variable receives the aggregated value that was saved:
        # 7.0 for SUM and 3.5 for MEAN.
        ckpt_normal = trackable_utils.Checkpoint(var=v_normal_restore)
        ckpt_normal.restore(manager.latest_checkpoint)
        if agg == variable_scope.VariableAggregation.SUM:
          self.assertEqual(7.0, self.evaluate(v_normal_restore.read_value()))
        else:
          self.assertEqual(3.5, self.evaluate(v_normal_restore.read_value()))

        # Save normal, restore ONREAD
        # `ckpt` and `manager` are rebound here to a new checkpoint that
        # holds only the plain variable (value 2.0).
        ckpt = trackable_utils.Checkpoint(var=v_normal_save)
        manager = ckpt_manager.CheckpointManager(
            ckpt, "/tmp/ckpt_" + str(uuid.uuid4()), max_to_keep=None)
        manager.save()
        # Restores a value of 2/2 = 1.0 for SUM and 2.0 for MEAN.
        ckpt_on_read = trackable_utils.Checkpoint(var=v_on_read)
        ckpt_on_read.restore(manager.latest_checkpoint)
        if agg == variable_scope.VariableAggregation.SUM:
          self.assertEqual(1.0, self.evaluate(v_on_read._values[0]))
        else:
          self.assertEqual(2.0, self.evaluate(v_on_read._values[0]))
1189
1190
@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
        ],
        aggregation=[
            variables_lib.VariableAggregation.MEAN,
            variables_lib.VariableAggregation.SUM,
            variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
        ],
        mode=["graph", "eager"],
        use_var_policy=[True, False]))
class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase):
  """Checks that scatter ops on ON_READ variables raise NotImplementedError.

  Every test creates a three-element ON_READ variable, builds a `PerReplica`
  delta of two `IndexedSlices` (indices [0, 1] and [1, 2]) and runs one
  `scatter_*` method through `distribution.run`.
  """

  def _assert_scatter_not_implemented(self, distribution, aggregation,
                                      scatter_method, initial_value,
                                      replica_values):
    """Runs `v.<scatter_method>` per replica and expects NotImplementedError.

    Args:
      distribution: the strategy under test.
      aggregation: the `VariableAggregation` used for the ON_READ variable.
      scatter_method: name of the variable method to run, e.g. "scatter_add".
      initial_value: three-element initial value of the variable.
      replica_values: pair of `IndexedSlices` values, one per `PerReplica`
        entry; the first is paired with indices [0, 1], the second with
        indices [1, 2].
    """
    with distribution.scope():
      v = variables_lib.Variable(
          initial_value,
          synchronization=variables_lib.VariableSynchronization.ON_READ,
          aggregation=aggregation)
    self.evaluate(v.initializer)

    first_values, second_values = replica_values
    delta = values.PerReplica([
        indexed_slices.IndexedSlices(
            values=first_values, indices=[0, 1], dense_shape=(3,)),
        indexed_slices.IndexedSlices(
            values=second_values, indices=[1, 2], dense_shape=(3,)),
    ])

    # The method is looked up inside the `assertRaises` block, matching where
    # the attribute access happened before this helper was factored out.
    with self.assertRaises(NotImplementedError):
      self.evaluate(
          distribution.run(getattr(v, scatter_method), args=(delta,)))

  def testScatterSub(self, distribution, aggregation):
    """`scatter_sub` is rejected."""
    self._assert_scatter_not_implemented(
        distribution, aggregation, "scatter_sub",
        initial_value=[1., 1., 1.],
        replica_values=([[0.], [1.]], [[1.], [2.]]))

  def testScatterAdd(self, distribution, aggregation):
    """`scatter_add` is rejected."""
    self._assert_scatter_not_implemented(
        distribution, aggregation, "scatter_add",
        initial_value=[1., 1., 1.],
        replica_values=([[0.], [1.]], [[1.], [2.]]))

  def testScatterDiv(self, distribution, aggregation):
    """`scatter_div` is rejected."""
    self._assert_scatter_not_implemented(
        distribution, aggregation, "scatter_div",
        initial_value=[2., 6., 1.],
        replica_values=([[2.], [2.]], [[3.], [3.]]))

  def testScatterMul(self, distribution, aggregation):
    """`scatter_mul` is rejected."""
    self._assert_scatter_not_implemented(
        distribution, aggregation, "scatter_mul",
        initial_value=[2., 1., 1.],
        replica_values=([[2.], [3.]], [[4.], [5.]]))

  def testScatterMin(self, distribution, aggregation):
    """`scatter_min` is rejected."""
    self._assert_scatter_not_implemented(
        distribution, aggregation, "scatter_min",
        initial_value=[3., 4., 5.],
        replica_values=([[1.], [8.]], [[9.], [2.]]))

  def testScatterMax(self, distribution, aggregation):
    """`scatter_max` is rejected."""
    self._assert_scatter_not_implemented(
        distribution, aggregation, "scatter_max",
        initial_value=[3., 4., 5.],
        replica_values=([[1.], [8.]], [[9.], [2.]]))

  def testScatterUpdate(self, distribution, aggregation):
    """`scatter_update` is rejected.

    This test previously called `scatter_min` (a copy-paste slip), so
    `scatter_update` itself was never exercised.
    """
    self._assert_scatter_not_implemented(
        distribution, aggregation, "scatter_update",
        initial_value=[0., 0., 0.],
        replica_values=([[1.], [2.]], [[3.], [4.]]))
1332
1333
# Script entry point: hands control to `main()` from
# tensorflow.python.distribute.test_util.
if __name__ == "__main__":
  test_util.main()
1336