# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TPUStrategy."""

from absl import logging
from absl.testing import parameterized

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import tpu_strategy as tpu_lib
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_hardware_feature
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.training import server_lib
from tensorflow.python.util import nest


FLAGS = flags.FLAGS
flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
flags.DEFINE_string("project", None, "Name of GCP project with TPU.")
flags.DEFINE_string("zone", None, "Name of GCP zone with TPU.")


def get_tpu_cluster_resolver():
  resolver = tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu,
      zone=FLAGS.zone,
      project=FLAGS.project,
  )
  return resolver


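# Builds a TPUStrategyV2 for the tests below: connects to the cluster named by
# the flags, initializes the TPU system, and optionally flips the private
# `_enable_packed_variable_in_eager_mode` toggle that the parameterized test
# cases exercise.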
def get_tpu_strategy(enable_packed_var=False):
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  tpu_strategy_util.initialize_tpu_system(resolver)
  strategy = tpu_lib.TPUStrategyV2(resolver)
  strategy._enable_packed_variable_in_eager_mode = enable_packed_var
  return strategy


# TPU tests which don't use TPUStrategy.
@test_util.with_eager_op_as_function
class TPUTest(test.TestCase):

  # In this case, the entire computation in foo is compiled using JIT
  # compilation.
  def test_single_tpu_jit_compile(self):
    with ops.device("/device:TPU:0"):
      a = variables.Variable(1)

    def get_a_plus_one():
      return a + 1

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def foo(x):
      b = x + get_a_plus_one()
      b = b + get_a_plus_one()
      return b + 1

    with ops.device("/device:TPU:0"):
      result = foo(a)
    self.assertAllEqual(6, result)

  # In this case, the entire computation in foo is compiled using JIT
  # compilation and contains unsupported ops that should be outside compiled.
  def test_single_tpu_jit_compile_with_outside_compilation(self):
    context.enable_jit_compile_rewrite()
    get_tpu_strategy(True)
    config.set_soft_device_placement(True)
    with ops.device("/device:TPU:0"):
      a = variables.Variable(1)

    def get_a_plus_one():
      return a + 1

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def foo(x):
      b = x + get_a_plus_one()
      my_str = string_ops.as_string(b)
      new_str = my_str + "0"
      c = string_ops.string_to_number(new_str, out_type=dtypes.int32)
      logging_ops.print_v2(c)
      b = c + get_a_plus_one()
      return b + 1

    with ops.device("/device:TPU:0"):
      result = foo(a)
    self.assertAllEqual(33, result)

  # In this case, each of the ops in the TPU device scope is compiled and run
  # individually.
  def test_single_tpu_on_demand(self):
    with ops.device("/device:TPU:0"):
      a = variables.Variable(1)

    def get_a_plus_one():
      return a + 1

    x = 1
    with ops.device("/device:TPU:0"):
      b = x + get_a_plus_one()
      b = b + get_a_plus_one()
    result = b + 1

    self.assertAllEqual(6, result)

  # In this case, each of the ops in the tf.function and TPU device scope is
  # compiled and run individually.
  def test_single_tpu_on_demand_tf_function(self):
    with ops.device("/device:TPU:0"):
      a = variables.Variable(1)

    def get_a_plus_one():
      return a + 1

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def foo(x):
      with ops.device("/device:TPU:0"):
        b = x + get_a_plus_one()
        b = b + get_a_plus_one()
      return b + 1

    result = foo(a)
    self.assertAllEqual(6, result)

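  # Initializing the TPU system a second time should only log a warning about
  # the system already being initialized rather than fail.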
  def test_multiple_initialize_system(self):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    tpu_strategy_util.initialize_tpu_system(resolver)

    with test.mock.patch.object(logging, "warning") as mock_log:
      tpu_strategy_util.initialize_tpu_system(resolver)
      self.assertRegex(str(mock_log.call_args), "already been initialized")

  def test_tpu_tf_function_same_device(self):
    with ops.device("/device:TPU:0"):
      a = variables.Variable(1)

    @function.defun_with_attributes(attributes={"_noinline": True})
    def get_a_plus_one():
      return a + 1

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def foo(x):
      with ops.device("/device:TPU:0"):
        b = x + get_a_plus_one()
      return b + 1

    result = foo(a)
    self.assertAllEqual(4, result)

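  # An int32 value returned from a tf.function running on a TPU core can be
  # consumed by ops placed on the CPU.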
  def test_tpu_return_int32(self):
    with ops.device("/device:TPU:0"):
      a = variables.Variable(0)

    @def_function.function
    def foo():
      return a + 1

    @def_function.function
    def bar():
      with ops.device("/device:TPU:1"):
        return foo()

    with ops.device("/device:CPU:0"):
      result = bar() + 1
      self.assertAllEqual(result, 2)

  def test_tpu_output_device(self):

    def foo():
      return 1 + 1

    func1 = function.defun_with_attributes(
        foo, attributes={"_XlaMustCompile": False})
    func2 = function.defun_with_attributes(
        foo, attributes={
            "_OutputsOnOpDevice": True,
            "_XlaMustCompile": False
        })

    with ops.device("/device:TPU:0"):
      ret1 = func1()
      ret2 = func2()

    self.assertAllEqual(ret1.backing_device,
                        "/job:localhost/replica:0/task:0/device:CPU:0")
    self.assertAllEqual(ret2.backing_device,
                        "/job:localhost/replica:0/task:0/device:TPU:0")

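  # Ops with data-dependent (dynamic) output shapes should still work when
  # executed on-demand on a TPU device.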
  def test_on_demand_op_with_dynamic_output(self):
    with ops.device("/device:TPU:0"):
      where_output = array_ops.where([True, False, True])
    self.assertAllEqual(where_output, [[0], [2]])

    with ops.device("/device:TPU:0"):
      repeat_output = array_ops.repeat(math_ops.range(2), [1, 4])
    self.assertAllEqual(repeat_output, [0, 1, 1, 1, 1])


@parameterized.named_parameters([("PackedVar", True), ("", False)])
@test_util.with_eager_op_as_function
class TPUStrategyTest(test.TestCase, parameterized.TestCase):

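  # Reading a strategy-created variable's handle inside a tf.function in
  # cross-replica context should resolve to the copy on TPU:0.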
  def test_handle_in_cross_replica_context(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)
    with strategy.scope():
      v = variables.Variable(1.0)

    @def_function.function
    def func():
      self.assertEndsWith(v.handle.device, "device:TPU:0")
      return v + 1.0

    ret = func()
    self.assertAllEqual(ret, 2.0)

  def testStaticHashTableDatasetFnHostTrainingLoop(self, enable_packed_var):
    self._dataset_fn_tracing_count = 0
    strategy = get_tpu_strategy(enable_packed_var)

    with strategy.scope():
      vals = [0, 1, 2]
      keys_tensor = constant_op.constant(
          list(range(len(vals))), dtype=dtypes.int64)
      vals_tensor = constant_op.constant(vals)
      initializer = lookup_ops.KeyValueTensorInitializer(
          keys_tensor, vals_tensor)
      per_worker_table = lookup_ops.StaticHashTable(
          initializer, default_value=-1)

    @def_function.function
    def dataset_fn(input_context):
      tensor = constant_op.constant([0, 1, 3], dtype=dtypes.int64)
      global_batch_size = 2
      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
      dataset = dataset_ops.Dataset.from_tensors(tensor).repeat().batch(
          batch_size, drop_remainder=True)
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
      dataset = dataset.prefetch(2)  # This prefetches 2 batches per device.
      dataset = dataset.map(per_worker_table.lookup)
      self._dataset_fn_tracing_count += 1
      return dataset

    dist_iterator = iter(
        strategy.experimental_distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def step_fn(inputs):
      # inputs should be [0, 1, -1]
      return math_ops.reduce_sum(inputs)

    def train_steps(iterator, steps):

      for _ in math_ops.range(steps):
        strategy.run(step_fn, args=(next(iterator),))

    train_steps(dist_iterator, steps=5)
    self.assertEqual(self._dataset_fn_tracing_count, 1)

  def test_function_compile_with_xla(self, enable_packed_var):
    if FLAGS.tpu_use_tfrt:
      self.skipTest(
          "This test triggers _XlaCompile and XlaLaunch which are not "
          "supported in tfrt yet. We should avoid using these kernels on TPU. "
          "However, it is a workaround to support b/129842431. We need more "
          "discussion about how to support it in the long term.")
    strategy = get_tpu_strategy(enable_packed_var)
    with strategy.scope():
      v = variables.Variable(1.0)

    @def_function.function
    def func():
      return v.read_value() + 1.0

    with ops.device("/device:TPU:0"):
      self.assertAllEqual(func(), 2.0)

  def test_sequential_runs(self, enable_packed_var):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)
    # Computation replicated to all cores.
    device_assignment = device_assignment_lib.DeviceAssignment.build(
        topology, num_replicas=2)
    strategy = tpu_lib.TPUStrategyV2(
        resolver, experimental_device_assignment=device_assignment)
    strategy._enable_packed_variable_in_eager_mode = enable_packed_var

    # Computation on the 1st core.
    device_assignment2 = device_assignment_lib.DeviceAssignment.build(
        topology, num_replicas=1)
    strategy2 = tpu_lib.TPUStrategyV2(
        resolver, experimental_device_assignment=device_assignment2)

    def computation(x):
      return math_ops.square(x)

    @def_function.function
    def train_step():
      outputs = strategy.experimental_local_results(
          strategy.run(computation, args=([2., 2.],)))
      outputs2 = strategy2.run(
          computation, args=([outputs[0]],))
      return outputs2

    self.assertAllEqual([[16., 16.]], train_step())

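  # control_flow_ops.switch_case can dispatch the same inference function to
  # different TPU cores, alternating the core between successive calls.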
  def test_device_switch_case(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)
    with strategy.scope():
      a = variables.Variable(1)

    inference_iteration = variables.Variable(-1)

    def inference_fn(x, i):
      return a + x + i

    @def_function.function
    def run_inference(x):

      def do_inference(device, inference_fn, i):
        with ops.device(device):
          return inference_fn(x, i)

      branch_fns = {
          0: (lambda: do_inference("/device:TPU:0", inference_fn, 0)),
          1: (lambda: do_inference("/device:TPU:1", inference_fn, 1)),
      }
      branch_index = inference_iteration.assign_add(1, use_locking=True) % 2
      return control_flow_ops.switch_case(branch_index, branch_fns)

    self.assertAllEqual(2., run_inference(1))  # Use TPU core 0.
    self.assertAllEqual(3., run_inference(1))  # Use TPU core 1.

  def test_recover_from_compilation_failures(self, enable_packed_var):
    # TODO(b/148150981): Stop skipping this test once recovery works
    # for non-local TPU.
    if FLAGS.tpu:
      self.skipTest("Recovery fails for non-local TPU, see b/148150981")

    # Disable automatic outside compilation.
    config.set_soft_device_placement(False)
    strategy = get_tpu_strategy(enable_packed_var)

    @def_function.function
    def compilation_failure_run():

      def computation():
        return random_ops.random_gamma([10], [0.5, 1.5])

      return strategy.run(computation)

    with self.assertRaises(errors.OpError):
      compilation_failure_run()

    @def_function.function
    def good_run():

      def computation():
        return random_ops.random_normal([10])

      return strategy.run(computation)

    good_run()

  def test_dynamic_shape_with_outside_compilation_failure(
      self, enable_packed_var):
    # Enable automatic outside compilation.
    config.set_soft_device_placement(True)
    strategy = get_tpu_strategy(enable_packed_var)
    dataset = dataset_ops.Dataset.from_tensors(("string", 1.0)).repeat().batch(
        2, drop_remainder=False)
    dataset = strategy.experimental_distribute_dataset(dataset)
    iterator = iter(dataset)

    @def_function.function
    def train_fn(iterator):

      def step_fn(inputs):
        input0, input1 = inputs
        return array_ops.size(input0), math_ops.reduce_sum(input1)

      return strategy.experimental_local_results(
          strategy.run(step_fn, args=(next(iterator),)))

    with self.assertRaises(errors.InvalidArgumentError):
      logging.info(train_fn(iterator))

  def test_computation_on_subset_cores(self, enable_packed_var):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)
    all_core_strategy = tpu_lib.TPUStrategyV2(resolver)
    all_core_strategy._enable_packed_variable_in_eager_mode = enable_packed_var

    with all_core_strategy.scope():
      v = variables.Variable(0.0,
                             aggregation=variables.VariableAggregation.MEAN)

    # Computation on the 1st core.
    device_assignment = device_assignment_lib.DeviceAssignment.build(
        topology, num_replicas=1)
    first_core_strategy = tpu_lib.TPUStrategyV2(
        resolver, experimental_device_assignment=device_assignment)
    first_core_strategy._enable_packed_variable_in_eager_mode = (
        enable_packed_var)

    # Computation on the 2nd core.
    device_assignment2 = device_assignment_lib.DeviceAssignment(
        topology, [[[0, 0, 0, 1]]])
    second_core_strategy = tpu_lib.TPUStrategyV2(
        resolver, experimental_device_assignment=device_assignment2)
    second_core_strategy._enable_packed_variable_in_eager_mode = (
        enable_packed_var)

    @def_function.function
    def train_step():

      def step_fn():
        return v + 1.0

      all_core_strategy.run(step_fn)
      r1 = first_core_strategy.run(step_fn)
      r2 = second_core_strategy.run(step_fn)
      return r1 + r2

    train_step()
    self.assertAllEqual(2., train_step())

  def test_worker_devices_on_subset_cores(self, enable_packed_var):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)

    # Strategy for the 1st core.
    device_assignment = device_assignment_lib.DeviceAssignment.build(
        topology, num_replicas=1)
    first_core_strategy = tpu_lib.TPUStrategyV2(
        resolver, experimental_device_assignment=device_assignment)
    first_core_strategy._enable_packed_variable_in_eager_mode = (
        enable_packed_var)

    # Strategy for the 2nd core.
    device_assignment2 = device_assignment_lib.DeviceAssignment(
        topology, [[[0, 0, 0, 1]]])
    second_core_strategy = tpu_lib.TPUStrategyV2(
        resolver, experimental_device_assignment=device_assignment2)
    second_core_strategy._enable_packed_variable_in_eager_mode = (
        enable_packed_var)

    self.assertLen(first_core_strategy.extended.worker_devices, 1)
    self.assertEndsWith(first_core_strategy.extended.worker_devices[0],
                        "device:TPU:0")

    self.assertLen(second_core_strategy.extended.worker_devices, 1)
    self.assertEndsWith(second_core_strategy.extended.worker_devices[0],
                        "device:TPU:1")

  def test_control_output_in_while_body_fn(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)

    with strategy.scope():
      v = variables.Variable(
          0.0, aggregation=variables.VariableAggregation.MEAN)

    @def_function.function
    def train_step():

      def step_fn():
        v.assign_add(1)

      for _ in math_ops.range(2):
        strategy.run(step_fn)

    train_step()
    self.assertEqual(2.0, v.numpy())

  def test_cluster_conditional_with_dynamic_shape(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)

    @def_function.function
    def train_step():

      def shape_list(tensor):
        shape = tensor.shape.as_list()

        non_static_indexes = []
        for (index, dim) in enumerate(shape):
          if dim is None:
            non_static_indexes.append(index)

        if not non_static_indexes:
          return shape

        dynamic_shape = array_ops.shape(input=tensor)
        for index in non_static_indexes:
          shape[index] = dynamic_shape[index]

        return shape

      def step_fn(condition):
        where = array_ops.where(condition)
        if array_ops.shape(where)[0] > 0:
          tensor_shape = shape_list(where)
          d1 = tensor_shape[0]
          d2 = tensor_shape[1]
          where = array_ops.reshape(where, [d1, d2])
        return where

      return strategy.run(step_fn, args=([True, False, True],))

    outputs = strategy.experimental_local_results(train_step())
    self.assertAllEqual(outputs[0].numpy(), [[0], [2]])

  def test_cluster_in_graph_and_while_body_fn(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)

    @def_function.function
    def train_step():

      def step_fn(prev):
        s = prev + 1
        return s

      def init_fn():
        return array_ops.zeros(shape=())

      prev = strategy.run(init_fn)
      for _ in math_ops.range(10):
        prev = strategy.run(step_fn, args=(prev,))
      return strategy.reduce(reduce_util.ReduceOp.SUM, prev, axis=None)

    sum_val = train_step().numpy().astype(float)
    self.assertEqual(sum_val, strategy.num_replicas_in_sync * 10)

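  # Calling the same strategy.run-wrapping function twice inside one outer
  # tf.function exercises two TPU computation clusters that share the body.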
  def test_two_clusters_with_same_fn(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)

    @def_function.function
    def foo(x):
      return strategy.run(lambda x: x + 1, (x,))

    @def_function.function
    def bar(x):
      foo(x)
      return foo(x)

    bar(1)

  def test_tpu_variable_run_argument(self, enable_packed_var):
    # TPUStrategy.run() casts inputs to Tensor, but has logic to preserve
    # variables to avoid unintuitive errors.
    # Here we test that a TPUDistributedVariable passed to TPUStrategy.run()
    # remains a variable.

    strategy = get_tpu_strategy(enable_packed_var)

    with strategy.scope():
      tpu_variable = variables.Variable(1)

    def replica_step(first_arg, variable):
      del first_arg  # Just here to make sure we're not relying on arg position.

      if variable is not None:
        self.assertIsInstance(variable, tpu_values.TPUDistributedVariable)

    @def_function.function
    def step():
      strategy.run(
          replica_step, args=(
              2,
              tpu_variable,
          ))

    step()

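  # Exercises how strategy.run forwards positional args, *args, kwargs and
  # keyword-only args, and which combinations involving variables it rejects.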
  def test_tpu_run_arg_parsing(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)

    with strategy.scope():
      tpu_vars = [variables.Variable(1)]

    def only_star_args(*args):
      del args

    def pos_and_star_args(first_arg, *args):
      del first_arg
      del args

    def named_args(first_arg, second_arg):
      del first_arg
      del second_arg

    def star_args_and_kw_only(*args, kw):
      del args
      del kw

    # pylint:disable=function-redefined
    @def_function.function
    def step():
      strategy.run(only_star_args, args=(2,))

    step()

    @def_function.function
    def step():
      strategy.run(named_args, kwargs={"first_arg": 2, "second_arg": 3})

    step()

    with self.assertRaisesRegex(TypeError, r"got multiple values for argument"):

      @def_function.function
      def step():
        strategy.run(
            named_args, args=(1,), kwargs={
                "first_arg": 2,
                "second_arg": 3
            })

      step()

    with self.assertRaisesRegex(ValueError,
                                r"cannot handle Variables passed to \*args"):

      @def_function.function
      def step():
        strategy.run(
            only_star_args, args=(
                2,
                tpu_vars,
            ))

      step()

    @def_function.function
    def step():
      strategy.run(pos_and_star_args, args=(2, 3, 4))

    step()

    @def_function.function
    def step():
      strategy.run(star_args_and_kw_only, args=(2, 3), kwargs={"kw": tpu_vars})

    step()

    with self.assertRaisesRegex(ValueError,
                                r"mix of positional args and \*args"):

      @def_function.function
      def step():
        strategy.run(pos_and_star_args, args=(tpu_vars, 3, 4))

      step()

    with self.assertRaisesRegex(ValueError, r"Too many positional arguments"):

      @def_function.function
      def step():
        strategy.run(named_args, args=(2, 3, 4))

      step()

    class DummyClass:

      @def_function.function
      def method(self, arg_1):
        del arg_1

      def step(self):
        strategy.run(self.method, args=(tpu_vars,))

    DummyClass().step()
    # pylint:enable=function-redefined

  def test_using_external_variable_inside_tf_function(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)
    dataset = dataset_ops.Dataset.range(
        strategy.num_replicas_in_sync * 2,
        output_type=dtypes.float32).batch(strategy.num_replicas_in_sync)
    input_iterator = iter(strategy.experimental_distribute_dataset(dataset))

    v = variables.Variable(2.0)

    @def_function.function
    def train_step(data):
      def computation(inputs):
        return inputs + v
      return strategy.run(computation, args=(data,))

    expected_result = [[x + 2.] for x in range(0, strategy.num_replicas_in_sync)
                      ]
    self.assertAllEqual(
        expected_result,
        strategy.experimental_local_results(train_step(next(input_iterator))))

  # TODO(b/145574622): Remove this test once it is re-enabled in values_test.py.
  def test_all_reduce_on_sync_on_read_variable(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)
    dataset = dataset_ops.Dataset.range(
        strategy.num_replicas_in_sync, output_type=dtypes.float32).batch(
            strategy.num_replicas_in_sync, drop_remainder=True)
    input_iterator = iter(strategy.experimental_distribute_dataset(dataset))

    with strategy.scope():
      w = variables.Variable(
          (0.,),
          shape=(1,),
          trainable=False,
          synchronization=variables.VariableSynchronization.ON_READ,
          aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)

    @def_function.function
    def run(iterator):

      def computation(x):
        w.assign(x + w)
        return w

      def all_reduce(x):
        ctx = distribution_strategy_context.get_replica_context()
        return ctx.all_reduce("SUM", w) + x

      outputs = strategy.run(computation, args=(next(iterator),))
      outputs2 = strategy.experimental_local_results(
          strategy.run(all_reduce, args=(outputs,)))
      return outputs2

    data = range(0, strategy.num_replicas_in_sync)
    data_sum = sum(data)
    expected_result = [
        [x + data_sum] for x in range(0, strategy.num_replicas_in_sync)
    ]
    self.assertAllEqual(expected_result, run(input_iterator))
    self.assertAllEqual((0.,), w.read_value())

  def test_run_output_on_device(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)

    def computation(x):
      return math_ops.square(x)

    @def_function.function
    def train_step():
      outputs = strategy.experimental_local_results(
          strategy.run(computation, args=(2,)))
      return outputs

    results = train_step()
    self.assertAllEqual([4., 4.], results)
    self.assertAllEqual("/job:localhost/replica:0/task:0/device:TPU:0",
                        results[0].backing_device)
    self.assertAllEqual("/job:localhost/replica:0/task:0/device:TPU:1",
                        results[1].backing_device)

  def test_run_passing_and_returning_nones(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)

    @def_function.function
    def train_step():

      def computation(x):
        return x

      # Note that this input None is nested.
      outputs = strategy.experimental_local_results(
          strategy.run(computation, args=([1, [2, None]],)))
      return outputs

    results = train_step()

    self.assertAllEqual(1, results[0][0])
    self.assertAllEqual(2, results[0][1][0])
    self.assertIsNone(results[0][1][1])

  def test_run_passing_and_returning_empty_list(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)

    @def_function.function
    def train_step():

      def computation(x):
        return x

      outputs = strategy.experimental_local_results(
          strategy.run(computation, args=([],)))
      return outputs

    self.assertEqual([], train_step()[0])

  def test_run_passing_and_returning_empty_dict(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)

    @def_function.function
    def train_step():

      def computation(x):
        return x

      outputs = strategy.experimental_local_results(
          strategy.run(computation, args=({},)))
      return outputs

    self.assertEqual({}, train_step()[0])

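  # SparseTensor inputs delivered via distribute_datasets_from_function with
  # experimental_fetch_to_device=False can be passed into and returned from
  # strategy.run.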
  def test_composite_input_output(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)
    if strategy.num_replicas_in_sync != 2:
      self.skipTest("Test assumes two replicas.")

    with strategy.scope():
      table = variables.Variable(
          initial_value=[[0.0, 1.0], [3.0, 7.0]], dtype=dtypes.float32)

    @def_function.function
    def sparse_lookup(iterator):

      def tpu_function(sparse):
        # Assumes dense_shape is (2, *)
        looked_up = array_ops.gather(table, sparse.values)
        segment_sum = math_ops.unsorted_segment_sum(
            looked_up, sparse.indices[:, 0], 2)
        return sparse, segment_sum

      return nest.map_structure(
          strategy.experimental_local_results,
          strategy.run(tpu_function, args=(next(iterator),)))

    def dataset_fn(_):
      dataset = dataset_ops.Dataset.range(2)

      def make_sparse(_):
        return sparse_tensor.SparseTensor(
            indices=array_ops.constant([[0, 0], [1, 0], [1, 1]],
                                       dtype=dtypes.int64),
            values=array_ops.constant([0, 0, 1], dtype=dtypes.int32),
            dense_shape=array_ops.constant([2, 2], dtype=dtypes.int64))

      return dataset.map(make_sparse)

    dataset = iter(
        strategy.distribute_datasets_from_function(
            dataset_fn,
            distribute_lib.InputOptions(experimental_fetch_to_device=False)))

    sparse, result = sparse_lookup(dataset)

    # All replicas return identical results.
    for replica in range(strategy.num_replicas_in_sync):
      self.assertIsInstance(sparse[replica], sparse_tensor.SparseTensor)
      self.assertAllEqual(sparse[replica].indices, [[0, 0], [1, 0], [1, 1]])
      self.assertAllEqual(sparse[replica].values, [0, 0, 1])
      self.assertAllEqual(sparse[replica].dense_shape, [2, 2])
      self.assertAllEqual(result[replica], [[0.0, 1.0], [3.0, 8.0]])

  def test_composite_input_non_flat_output(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)
    if strategy.num_replicas_in_sync != 2:
      self.skipTest("Test assumes two replicas.")

    with strategy.scope():
      table = variables.Variable(
          initial_value=[[0.0, 1.0], [3.0, 7.0]], dtype=dtypes.float32)

    @def_function.function
    def sparse_lookup(iterator):

      def tpu_function(sparse):
        # Assumes dense_shape is (2, *)
        looked_up = array_ops.gather(table, sparse.values)
        segment_sum = math_ops.unsorted_segment_sum(
            looked_up, sparse.indices[:, 0], 2)
        return {"sparse": sparse, "segment_sum": segment_sum}

      return nest.map_structure(
          strategy.experimental_local_results,
          strategy.run(tpu_function, args=(next(iterator),)))

    def dataset_fn(_):
      dataset = dataset_ops.Dataset.range(2)

      def make_sparse(_):
        return sparse_tensor.SparseTensor(
            indices=array_ops.constant([[0, 0], [1, 0], [1, 1]],
                                       dtype=dtypes.int64),
            values=array_ops.constant([0, 0, 1], dtype=dtypes.int32),
            dense_shape=array_ops.constant([2, 2], dtype=dtypes.int64))

      return dataset.map(make_sparse)

    dataset = iter(
        strategy.distribute_datasets_from_function(
            dataset_fn,
            distribute_lib.InputOptions(experimental_fetch_to_device=False)))

    output = sparse_lookup(dataset)

    # All replicas return identical results.
    for replica in range(strategy.num_replicas_in_sync):
      self.assertIsInstance(output["sparse"][replica],
                            sparse_tensor.SparseTensor)
      self.assertAllEqual(output["sparse"][replica].indices,
                          [[0, 0], [1, 0], [1, 1]])
      self.assertAllEqual(output["sparse"][replica].values, [0, 0, 1])
      self.assertAllEqual(output["sparse"][replica].dense_shape, [2, 2])
      self.assertAllEqual(output["segment_sum"][replica],
                          [[0.0, 1.0], [3.0, 8.0]])

  def test_composite_input_dynamic_shapes_outside_compilation(
      self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)
    if strategy.num_replicas_in_sync != 2:
      self.skipTest("Test assumes two replicas.")

    table = variables.Variable(
        initial_value=[[0.0, 1.0], [3.0, 7.0]], dtype=dtypes.float32)

    @def_function.function
    def sparse_lookup(iterator):

      def tpu_function(sparse):
        lookup = tpu.outside_compilation(
            embedding_ops.safe_embedding_lookup_sparse, table, sparse)
        return math_ops.reduce_sum(lookup, axis=0)

      return strategy.experimental_local_results(
          strategy.run(tpu_function, args=(next(iterator),)))

    def dataset_fn(_):
      dataset = dataset_ops.Dataset.range(2)

      def make_sparse(i):
        indices = array_ops.constant([[0, 0], [1, 0], [1, 1]],
                                     dtype=dtypes.int64)[0:2 + i]
        values = array_ops.constant([0, 0, 1], dtype=dtypes.int32)[0:2 + i]
        shape = [
            array_ops.constant([2], dtype=dtypes.int64),
            array_ops.expand_dims(1 + i, axis=0)
        ]
        dense_shape = array_ops.concat(shape, axis=0)
        return sparse_tensor.SparseTensor(
            indices=indices, values=values, dense_shape=dense_shape)

      return dataset.map(make_sparse)

    dataset = iter(
        strategy.distribute_datasets_from_function(
            dataset_fn,
            options=distribute_lib.InputOptions(
                experimental_fetch_to_device=False)))

    result = sparse_lookup(dataset)
    self.assertAllEqual(result, [[0.0, 2.0], [1.5, 5.0]])

  def test_composite_input_with_non_flat_components(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)

    class TestCompositeTypeSpec(type_spec.TypeSpec):

      def __init__(self, component_type_spec):
        self._component_type_spec = component_type_spec

      @property
      def value_type(self):
        return TestComposite

      def _to_components(self, value):
        return value.values

      def _from_components(self, components):
        return TestComposite(components[0], components[1][0], components[1][1])

      @property
      def _component_specs(self):
        return [self._component_type_spec,
                [self._component_type_spec, self._component_type_spec]]

      def _serialize(self):
        return (self._component_type_spec,)

    class TestComposite(composite_tensor.CompositeTensor):

      def __init__(self, value1, value2, value3):
        self.values = [value1, [value2, value3]]

      @property
      def _type_spec(self):
        return TestCompositeTypeSpec(
            tensor_spec.TensorSpec.from_tensor(self.values[0]))

      def _shape_invariant_to_type_spec(self, shape):
        return [shape, [shape, shape]]

    @def_function.function
    def test_fn(test_composite):

      def tpu_function(composite):
        return (composite,
                composite.values[0] + (
                    composite.values[1][0] + composite.values[1][1])/2)

      return nest.map_structure(
          strategy.experimental_local_results,
          strategy.run(tpu_function, args=(test_composite,)))

    a = array_ops.constant([0.1])
    b = array_ops.constant([1.2])
    c = array_ops.constant([-0.4])
    test_composite = TestComposite(a, b, c)

    composite, result = test_fn(test_composite)

    # All replicas return identical results.
    for replica in range(strategy.num_replicas_in_sync):
      self.assertIsInstance(composite[replica], TestComposite)
      self.assertAllEqual(composite[replica].values[0], a)
      self.assertAllEqual(composite[replica].values[1][0], b)
      self.assertAllEqual(composite[replica].values[1][1], c)
      self.assertAllEqual(result[replica], array_ops.constant([0.50000006]))

  def test_per_device_tracing_of_mirrored_variables(self, enable_packed_var):
    # Define trace_count as a list to avoid python scoping error
    trace_count = [0]

    strategy = get_tpu_strategy(enable_packed_var)
    with strategy.scope():
      variable = variables.Variable(0.0)

    @def_function.function
    def add_one():
      trace_count[0] = trace_count[0] + 1
      return math_ops.add(variable, constant_op.constant(1.0))

    @def_function.function
    def update_variable():
      for device in set(strategy.extended.worker_devices):
        with ops.device(device):
          add_one()

    with strategy.scope():
      update_variable.get_concrete_function()
      self.assertLen(strategy.extended.worker_devices, trace_count[0])

  def test_tpu_cancellation_does_not_close_chips(self, enable_packed_var):
    if not FLAGS.tpu_use_tfrt:
      self.skipTest(
          "`tpu_cancellation_closes_chip` only applies to TFRT TPU Runtime.")
    strategy = get_tpu_strategy(enable_packed_var)
    num_replicas = strategy.num_replicas_in_sync
    with strategy.scope():
      x = random_ops.random_normal((10240, 10240))
      y = random_ops.random_normal((10240, 10240))

      v = variables.Variable(array_ops.identity(x))
      dist_dataset = strategy.experimental_distribute_dataset(
          dataset_ops.Dataset.from_tensors(y).repeat(num_replicas).batch(
              num_replicas))
      dist_iterator = iter(dist_dataset)

      @def_function.function
      def train_steps(v, iterator, steps):

        def step_fn(inputs):
          for val in inputs:
            v.assign(math_ops.matmul(v, val))

        for _ in math_ops.range(steps):
          strategy.run(step_fn, args=(next(iterator),))

      with self.assertRaises(errors.OutOfRangeError):
        # The iterator has num_replicas/num_replicas = 1 step only.
        train_steps(v, dist_iterator, 2)

      # If TPU chips are not closed we can run the function on TPU again.
      w = variables.Variable(array_ops.identity(x))
      dist_dataset = strategy.experimental_distribute_dataset(
          dataset_ops.Dataset.from_tensors(y).repeat(num_replicas).batch(
              num_replicas))
      dist_iterator = iter(dist_dataset)
      train_steps(w, dist_iterator, 1)

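  # The strategy should report the TPU embedding hardware feature as an
  # EmbeddingFeature enum value.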
  def test_tpu_hardware_feature(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)
    self.assertIsInstance(
        strategy.extended.tpu_hardware_feature.embedding_feature,
        tpu_hardware_feature.HardwareFeature.EmbeddingFeature)

  def test_get_tpu_cluster_resolver(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)
    self.assertIsNotNone(strategy.cluster_resolver)


@test_util.with_eager_op_as_function
class TPUStrategyDataPrefetchTest(test.TestCase):

  def test_prefetch_to_device_default(self):
    strategy = get_tpu_strategy()
    dataset = dataset_ops.Dataset.range(
        strategy.num_replicas_in_sync * 2,
        output_type=dtypes.float32).batch(strategy.num_replicas_in_sync)

    # Check default, should prefetch to TPU.
    dataset_item = next(iter(strategy.experimental_distribute_dataset(dataset)))
    dataset_location = tf_device.DeviceSpec.from_string(
        dataset_item.values[0].device)
    self.assertEqual(dataset_location.device_type, "TPU")

  def test_prefetch_to_device_tpu(self):
    strategy = get_tpu_strategy()
    dataset = dataset_ops.Dataset.range(
        strategy.num_replicas_in_sync * 2,
        output_type=dtypes.float32).batch(strategy.num_replicas_in_sync)

    input_options = distribute_lib.InputOptions(
        experimental_fetch_to_device=True)
    dataset_item = next(iter(strategy.experimental_distribute_dataset(
        dataset, options=input_options)))
    dataset_location = tf_device.DeviceSpec.from_string(
        dataset_item.values[0].device)
    self.assertEqual(dataset_location.device_type, "TPU")

  def test_prefetch_to_device_cpu(self):
    strategy = get_tpu_strategy()
    dataset = dataset_ops.Dataset.range(
        strategy.num_replicas_in_sync * 2,
        output_type=dtypes.float32).batch(strategy.num_replicas_in_sync)

    # Should be CPU when prefetch_to_device is False.
    input_options = distribute_lib.InputOptions(
        experimental_fetch_to_device=False)
    dataset_item = next(iter(strategy.experimental_distribute_dataset(
        dataset, options=input_options)))
    dataset_location = tf_device.DeviceSpec.from_string(
        dataset_item.values[0].device)
    self.assertEqual(dataset_location.device_type, "CPU")

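  # Sparse and ragged datasets cannot be prefetched to the TPU device; the
  # tests below expect a ValueError instead.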
  def test_prefetch_to_device_sparse_dataset(self):
    strategy = get_tpu_strategy()
    # Values here aren't important.
    dataset = dataset_ops.Dataset.from_tensors(
        sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                                   values=[1, 2, 3],
                                   dense_shape=[2, 2]))
    dataset = dataset.repeat()
    dataset = dataset.batch(strategy.num_replicas_in_sync)

    with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
      iter(strategy.experimental_distribute_dataset(dataset))

  def test_prefetch_to_device_ragged_dataset(self):
    strategy = get_tpu_strategy()
    # Values here aren't important.
    dataset = dataset_ops.Dataset.from_tensors(
        ragged_tensor.RaggedTensor.from_row_splits(
            values=[1, 2, 3],
            row_splits=[0, 2, 3]))
    dataset = dataset.repeat()
    dataset = dataset.batch(strategy.num_replicas_in_sync)

    with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
      iter(strategy.experimental_distribute_dataset(dataset))

  def test_prefetch_to_device_sparse_dataset_fn(self):
    strategy = get_tpu_strategy()
    def dataset_fn(ctx):
      del ctx
      # Values here aren't important.
      dataset = dataset_ops.Dataset.from_tensors(
          sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                                     values=[1, 2, 3],
                                     dense_shape=[2, 2]))
      dataset = dataset.repeat()
      return dataset.batch(strategy.num_replicas_in_sync)

    with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
      iter(strategy.distribute_datasets_from_function(dataset_fn))

  def test_prefetch_to_device_ragged_dataset_fn(self):
    strategy = get_tpu_strategy()
    def dataset_fn(ctx):
      del ctx
      # Values here aren't important.
      dataset = dataset_ops.Dataset.from_tensors(
          ragged_tensor.RaggedTensor.from_row_splits(
              values=[1, 2, 3],
              row_splits=[0, 2, 3]))
      dataset = dataset.repeat()
      return dataset.batch(strategy.num_replicas_in_sync)

    with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
      iter(strategy.distribute_datasets_from_function(dataset_fn))

  def test_create_iterator_on_device(self):

    @def_function.function
    def create_iter():
      with ops.device("/device:TPU:0"):
        return gen_dataset_ops.anonymous_iterator_v3(
            output_types=[dtypes.float32], output_shapes=[[]])

    create_iter()


@test_util.with_eager_op_as_function
class TPUStrategyDistributionTest(
    strategy_test_lib.DistributionTestBase,
    strategy_test_lib.TwoDeviceDistributionTestBase):

  def test_update_config_proto(self):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    tpu_strategy_util.initialize_tpu_system(resolver)
    strategy = tpu_lib.TPUStrategyV2(resolver)

    config_proto = config_pb2.ConfigProto()
    cluster_spec = server_lib.ClusterSpec({"worker": ["fake1", "fake2"]})
    with test.mock.patch.object(
        resolver, "cluster_spec", return_value=cluster_spec):
      new_config = strategy.update_config_proto(config_proto)

    # Verify cluster_def.
    self.assertProtoEquals(cluster_spec.as_cluster_def(),
                           new_config.cluster_def)

    # Verify isolate_session_state
    self.assertTrue(new_config.isolate_session_state)

  def test_make_input_fn_iterable(self):
    dataset_fn = lambda: dataset_ops.Dataset.range(10)
    expected_values = [[i, i+1] for i in range(0, 10, 2)]
    distribution = get_tpu_strategy()
    input_fn = self._input_fn_to_test_input_context(
        dataset_fn,
        expected_num_replicas_in_sync=2,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)
    self._test_input_fn_iterable(distribution, input_fn, expected_values)

  def test_make_input_fn_iterator(self):
    dataset_fn = lambda: dataset_ops.Dataset.range(10)
    expected_values = [[i, i+1] for i in range(0, 10, 2)]
    distribution = get_tpu_strategy()
    input_fn = self._input_fn_to_test_input_context(
        dataset_fn,
        expected_num_replicas_in_sync=2,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)
    iterator = distribution.make_input_fn_iterator(input_fn)
    self._test_input_fn_iterator(
        iterator,
        distribution.extended.worker_devices,
        expected_values)

  def test_num_replicas_in_sync(self):
    strategy = get_tpu_strategy()
    self.assertEqual(2, strategy.num_replicas_in_sync)

  def test_call_and_merge_exceptions(self):
    strategy = get_tpu_strategy()
    self._test_call_and_merge_exceptions(strategy)

  def test_numpy_dataset(self):
    strategy = get_tpu_strategy()
    self._test_numpy_dataset(strategy, run_in_function=True)

  def test_global_step_update(self):
    strategy = get_tpu_strategy()
    self._test_global_step_update(strategy)

  def test_run(self):
    strategy = get_tpu_strategy()
    self._test_run(strategy, run_in_function=True)

  def test_summary_for_replica_zero_only(self):
    strategy = get_tpu_strategy()
    self._test_summary_for_replica_zero_only(strategy)

  def test_all_reduce_sum(self):
    strategy = get_tpu_strategy()
    self._test_all_reduce_sum(strategy, run_in_function=True)

  def test_all_reduce_sum_gradients(self):
    strategy = get_tpu_strategy()
    self._test_all_reduce_sum_gradients(strategy, run_in_function=True)

  def test_all_reduce_sum_gradient_tape(self):
    strategy = get_tpu_strategy()
    self._test_all_reduce_sum_gradient_tape(strategy, run_in_function=True)

  def test_all_reduce_mean(self):
    strategy = get_tpu_strategy()
    self._test_all_reduce_mean(strategy, run_in_function=True)

  def test_all_reduce_mean_gradients(self):
    strategy = get_tpu_strategy()
    self._test_all_reduce_mean_gradients(strategy, run_in_function=True)

  def test_all_reduce_mean_gradient_tape(self):
    strategy = get_tpu_strategy()
    self._test_all_reduce_mean_gradient_tape(strategy, run_in_function=True)

  def test_reduce(self):
    strategy = get_tpu_strategy()

    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices([2., 3.]))

    self.evaluate(inputs.initialize())
    per_replica_outputs = strategy.run(
        def_function.function(math_ops.square), args=(next(inputs),))

    with strategy.scope():
      mean = strategy.reduce(reduce_util.ReduceOp.MEAN, per_replica_outputs,
                             axis=None)
      self.assertEqual(6.5, self.evaluate(mean))

  def test_constraint(self):
    strategy = get_tpu_strategy()

    with strategy.scope():
      variable = variables.Variable(initial_value=2.,
                                    constraint=lambda x: 0. * x + 1.)
    self.assertEqual(variable.value().numpy(), 2)

    @def_function.function
    def update_variable():
      variable.assign_add(1)
      variable.assign(variable.constraint(variable))

    update_variable()
    self.assertEqual(variable.value().numpy(), 1)

  def test_trainable_variables(self):
    strategy = get_tpu_strategy()
    self._test_trainable_variable(strategy)


@test_util.with_eager_op_as_function
class DeviceAssignmentTest(test.TestCase):

  def test_core_assignment(self):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)
    device_assignment = device_assignment_lib.DeviceAssignment(
        topology, core_assignment=[[[0, 0, 0, 0]]])
    self.assertAllEqual([[[0, 0, 0, 0]]], device_assignment.core_assignment)
    self.assertEqual(1, device_assignment.num_cores_per_replica)
    self.assertEqual(1, device_assignment.num_replicas)
    self.assertEqual("/task:0/device:TPU:0", device_assignment.tpu_device())
    self.assertEqual("/task:0/device:CPU:0", device_assignment.host_device())

  def test_device_assignment_strategy_properties(self):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)
    device_assignment = device_assignment_lib.DeviceAssignment(
        topology, core_assignment=[[[0, 0, 0, 0]]])
    strategy = tpu_lib.TPUStrategyV2(
        resolver,
        experimental_device_assignment=device_assignment)
    self.assertEqual(strategy.extended.num_hosts, 1)
    self.assertEqual(strategy.num_replicas_in_sync, 1)
    self.assertEqual(strategy.extended.num_replicas_per_host, 1)  # pylint: disable=protected-access

  def test_device_assignment_constants(self):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)
    device_assignment = device_assignment_lib.DeviceAssignment(
        topology,
        core_assignment=device_assignment_lib.SINGLE_CORE_ASSIGNMENT)
    self.assertAllEqual([[[0, 0, 0, 0]]], device_assignment.core_assignment)
    self.assertEqual(1, device_assignment.num_cores_per_replica)
    self.assertEqual(1, device_assignment.num_replicas)
    self.assertEqual("/task:0/device:TPU:0", device_assignment.tpu_device())
    self.assertEqual("/task:0/device:CPU:0", device_assignment.host_device())

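  # A variable created under one strategy can be read under a second strategy
  # whose device assignment reverses the cores; the per-device values come
  # back in the second strategy's device order.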
  def test_variables_mismatched_device_assignment(self):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)

    strategy0 = tpu_lib.TPUStrategyV2(resolver)
    self.assertEqual(
        ("/job:localhost/replica:0/task:0/device:TPU:0",
         "/job:localhost/replica:0/task:0/device:TPU:1"),
        strategy0.extended.worker_devices)

    with strategy0.scope():
      v = variables.Variable(1.)

    v1_assign_op = strategy0.experimental_local_results(v)[1].assign(42.)

    with self.cached_session():
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(v1_assign_op)
      self.assertAllEqual([1., 42.],
                          self.evaluate(
                              strategy0.experimental_local_results(v)))

    # Second strategy has devices reversed relative to the first.
    device_assignment = device_assignment_lib.DeviceAssignment(
        topology, core_assignment=[[[0, 0, 0, 1]], [[0, 0, 0, 0]]])
    strategy1 = tpu_lib.TPUStrategyV2(
        resolver,
        experimental_device_assignment=device_assignment)
    self.assertEqual(
        ("/job:localhost/replica:0/task:0/device:TPU:1",
         "/job:localhost/replica:0/task:0/device:TPU:0"),
        strategy1.extended.worker_devices)

    v_read = strategy1.run(def_function.function(v.read_value))

    with self.cached_session():
      self.assertAllEqual([42., 1.],
                          self.evaluate(
                              strategy0.experimental_local_results(v_read)))


if __name__ == "__main__":
  test.main()