xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/tpu_strategy_compilation_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for TPUStrategy in regards to compiling programs."""
16
17from tensorflow.python.distribute import tpu_strategy as tpu_lib
18from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
19from tensorflow.python.eager import def_function
20from tensorflow.python.eager import remote
21from tensorflow.python.eager import test
22from tensorflow.python.framework import constant_op
23from tensorflow.python.platform import flags
24from tensorflow.python.tpu import tpu_strategy_util
25
# Command-line flags that identify the TPU this test connects to; they are
# read by get_tpu_cluster_resolver().
FLAGS = flags.FLAGS
flags.DEFINE_string(name="tpu", default="", help="Name of TPU to connect to.")
flags.DEFINE_string(
    name="project", default=None, help="Name of GCP project with TPU.")
flags.DEFINE_string(
    name="zone", default=None, help="Name of GCP zone with TPU.")
30
31
def get_tpu_cluster_resolver():
  """Builds a TPUClusterResolver from the --tpu, --zone and --project flags.

  Returns:
    The `tpu_cluster_resolver.TPUClusterResolver` constructed from the
    current flag values.
  """
  resolver_kwargs = dict(
      tpu=FLAGS.tpu,
      zone=FLAGS.zone,
      project=FLAGS.project,
  )
  return tpu_cluster_resolver.TPUClusterResolver(**resolver_kwargs)
39
40
def get_tpu_strategy():
  """Connects to the TPU cluster, initializes it and builds a strategy.

  Returns:
    A `tpu_lib.TPUStrategyV2` created from the resolver returned by
    `get_tpu_cluster_resolver()`.
  """
  cluster_resolver = get_tpu_cluster_resolver()
  # The same resolver is used for all three steps, in this order: connect to
  # the cluster, initialize the TPU system, then construct the strategy.
  remote.connect_to_cluster(cluster_resolver)
  tpu_strategy_util.initialize_tpu_system(cluster_resolver)
  return tpu_lib.TPUStrategyV2(cluster_resolver)
47
48
49# TODO(b/158494076): Merge this test back into TPUStrategy tests
50# (tpu_strategy_test) once MLIR bridge is enabled by default.
class TPUStrategyCompilationTest(test.TestCase):
  """Compilation tests that run small programs through a TPU strategy."""

  def test_functions_compile_same_signature(self):
    """Tests compiling different functions with the same signature."""
    tpu_strategy = get_tpu_strategy()

    @def_function.function
    def return_one():

      def computation():
        return constant_op.constant(1)

      return tpu_strategy.run(computation)

    @def_function.function
    def return_two():

      def computation():
        return constant_op.constant(2)

      return tpu_strategy.run(computation)

    # Each function should yield its constant once per replica. The functions
    # are invoked in order: return_one first, then return_two.
    for constant_value, compiled_fn in ((1, return_one), (2, return_two)):
      replica_count = tpu_strategy.num_replicas_in_sync
      expected = [constant_value] * replica_count
      local_results = tpu_strategy.experimental_local_results(compiled_fn())
      self.assertAllEqual(expected, local_results)
80
81
# Run the tests through the TensorFlow test runner when this file is executed
# directly as a script.
if __name__ == "__main__":
  test.main()
84