# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TPUStrategy in regards to compiling programs."""

from tensorflow.python.distribute import tpu_strategy as tpu_lib
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import flags
from tensorflow.python.tpu import tpu_strategy_util

FLAGS = flags.FLAGS
flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
flags.DEFINE_string("project", None, "Name of GCP project with TPU.")
flags.DEFINE_string("zone", None, "Name of GCP zone with TPU.")


def get_tpu_cluster_resolver():
  """Creates a TPUClusterResolver from the --tpu/--zone/--project flags.

  Returns:
    A `tpu_cluster_resolver.TPUClusterResolver` built from the current flag
    values.
  """
  return tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)


def get_tpu_strategy():
  """Sets up the flagged TPU and wraps it in a `TPUStrategyV2`.

  The function first calls `remote.connect_to_cluster` on the resolver. It
  then calls `tpu_strategy_util.initialize_tpu_system`, and only after both
  does it construct the strategy.

  Returns:
    A `tpu_lib.TPUStrategyV2` bound to the flag-configured cluster resolver.
  """
  cluster = get_tpu_cluster_resolver()
  remote.connect_to_cluster(cluster)
  tpu_strategy_util.initialize_tpu_system(cluster)
  return tpu_lib.TPUStrategyV2(cluster)


# TODO(b/158494076): Merge this test back into TPUStrategy tests
# (tpu_strategy_test) once MLIR bridge is enabled by default.
class TPUStrategyCompilationTest(test.TestCase):
  """Tests how TPUStrategy compiles and runs `tf.function` programs."""

  def test_functions_compile_same_signature(self):
    """Tests compiling different functions with the same signature.

    `return_one` and `return_two` both take no arguments, so their input
    signatures are identical. They return different constants, though. The
    assertions check that each call yields its own constant in every
    per-replica result.
    """
    strategy = get_tpu_strategy()

    @def_function.function
    def return_one():

      def computation():
        return constant_op.constant(1)

      return strategy.run(computation)

    @def_function.function
    def return_two():

      def computation():
        return constant_op.constant(2)

      return strategy.run(computation)

    num_replicas = strategy.num_replicas_in_sync

    # Every replica runs the same constant-returning computation, so each
    # per-replica result must equal that function's constant.
    self.assertAllEqual([1] * num_replicas,
                        strategy.experimental_local_results(return_one()))
    self.assertAllEqual([2] * num_replicas,
                        strategy.experimental_local_results(return_two()))


if __name__ == "__main__":
  test.main()