# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.distribute.combinations."""

import importlib
import os
import sys
import unittest

from absl.testing import parameterized

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute.cluster_resolver import tfconfig_cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.framework import combinations as framework_combinations
from tensorflow.python.platform import test


class ClusterCombinationTest(test.TestCase, parameterized.TestCase):
  """Checks how ClusterCombination supplies cluster parameters to tests."""
  # These tests use `framework_combinations.generate` rather than
  # `combinations.generate`, because the latter consumes the cluster
  # parameters itself and would not hand them to the test method.
  #
  # `combinations.ClusterCombination()` is passed explicitly through
  # `test_combinations` so that `has_chief` / `num_workers` are taken from
  # each NamedDistribution (or defaulted when the distribution sets none).

  @framework_combinations.generate(  # pylint: disable=redundant-keyword-arg
      framework_combinations.combine(distribution=[
          combinations.NamedDistribution(
              "HasClusterParams", lambda: None, has_chief=True, num_workers=2),
      ]),
      test_combinations=(combinations.ClusterCombination(),))
  def testClusterParams(self, distribution, has_chief, num_workers):
    """Explicit cluster params on the distribution reach the test method."""
    self.assertTrue(has_chief)
    self.assertEqual(num_workers, 2)

  @framework_combinations.generate(  # pylint: disable=redundant-keyword-arg
      framework_combinations.combine(distribution=[
          combinations.NamedDistribution("NoClusterParams", lambda: None),
      ]),
      test_combinations=(combinations.ClusterCombination(),))
  def testClusterParamsHasDefault(self, distribution, has_chief, num_workers):
    """A distribution without cluster params yields no chief, one worker."""
    self.assertFalse(has_chief)
    self.assertEqual(num_workers, 1)

  @framework_combinations.generate(  # pylint: disable=redundant-keyword-arg
      framework_combinations.combine(v=1),
      test_combinations=(combinations.ClusterCombination(),))
  def testClusterParamsNoStrategy(self, v, has_chief, num_workers):
    """Defaults are supplied even when no distribution is in the combo."""
    self.assertFalse(has_chief)
    self.assertEqual(num_workers, 1)

  @framework_combinations.generate(  # pylint: disable=redundant-keyword-arg
      framework_combinations.combine(distribution=[
          combinations.NamedDistribution(
              "WithClusterParams", lambda: None, has_chief=True, num_workers=2),
          combinations.NamedDistribution("WithoutClusterParams", lambda: None),
      ]),
      test_combinations=(combinations.ClusterCombination(),))
  def testClusterParamsAreOptional(self, distribution):
    # The method does not accept has_chief/num_workers. If the combinations
    # library doesn't raise an exception, the test has passed.
    pass

  @framework_combinations.generate(  # pylint: disable=redundant-keyword-arg
      framework_combinations.combine(
          ds1=combinations.NamedDistribution(
              "Strategy1", lambda: None, has_chief=True, num_workers=0),
          ds2=combinations.NamedDistribution(
              "Strategy2", lambda: None, has_chief=False, num_workers=1),
          ds3=combinations.NamedDistribution(
              "Strategy3", lambda: None, has_chief=True, num_workers=0),
      ),
      test_combinations=(combinations.ClusterCombination(),))
  def testMultipleDistributionSingleWorker(self, ds1, ds2, ds3):
    # Several distributions that each describe a single task are accepted.
    # If the combinations library doesn't raise an exception, the test has
    # passed.
    pass

  @combinations.generate(combinations.combine(num_workers=2))
  def testUseWithoutStrategy(self):
    # There's no perfect way to check if the test runs in a subprocess. We
    # approximate by checking the presence of TF_CONFIG, which is normally not
    # set in the main process. `os.getenv` returns None for an unset
    # variable, so a plain `!= ""` comparison would pass vacuously; require a
    # non-empty value instead.
    self.assertTrue(
        os.getenv("TF_CONFIG"),
        "TF_CONFIG is unset or empty; the test does not appear to run in a "
        "worker process.")


@combinations.generate(combinations.combine(num_workers=2))
class ClusterCombinationTestEnvTest(test.TestCase, parameterized.TestCase):
  """Checks that combinations.env() set in the main process reaches workers."""

  def setUp(self):
    # Note that test case fixtures are executed in both the main process and
    # worker processes; only the main process may update the environment.
    super().setUp()
    if combinations.in_main_process():
      combinations.env().tf_data_service_dispatcher = "localhost"

  def testTfDataServiceDispatcher(self):
    """The value written in the main process is visible to the worker."""
    self.assertEqual(combinations.env().tf_data_service_dispatcher, "localhost")

  def testUpdateEnvInWorker(self):
    """Writing to the environment from a worker raises ValueError."""
    with self.assertRaises(ValueError):
      combinations.env().tf_data_service_dispatcher = "localhost"


# unittest.expectedFailure doesn't work with parameterized test methods, so we
# have to decorate the class instead.
@unittest.expectedFailure
class ClusterParametersShouldFailTest(test.TestCase, parameterized.TestCase):
  """Two multi-worker distributions in one combination must be rejected."""

  @framework_combinations.generate(  # pylint: disable=redundant-keyword-arg
      framework_combinations.combine(
          ds1=combinations.NamedDistribution(
              "Strategy1", lambda: None, has_chief=True, num_workers=2),
          ds2=combinations.NamedDistribution(
              "Strategy2", lambda: None, has_chief=True, num_workers=2),
      ),
      test_combinations=(combinations.ClusterCombination(),))
  def testMultipleDistributionMultiWorker(self, ds1, ds2):
    """The combinations library is expected to raise for this combination.

    The body does nothing on purpose: the class-level expectedFailure only
    turns into a pass if the library itself raises.
    """


@unittest.expectedFailure
class CombinationsExpectedFailureTest(test.TestCase, parameterized.TestCase):
  """A method-level multi-worker test must actually run, and be able to fail."""

  @combinations.generate(
      combinations.combine(distribution=[
          combinations.NamedDistribution(
              "OneChiefOneWorker", lambda: None, has_chief=True, num_workers=1),
          combinations.NamedDistribution(
              "TwoWorkers", lambda: None, has_chief=False, num_workers=2),
      ]))
  def testMultiWorkerCanFail(self, distribution):
    # Presumably each worker's TF_CONFIG gives the resolver a task id, so this
    # assertion is meant to fail (hence expectedFailure on the class).
    current_task = tfconfig_cluster_resolver.TFConfigClusterResolver().task_id
    self.assertIsNone(current_task)


# The class below makes sure the multi-worker test body is *actually* executed
# in worker processes instead of silently passing; more importantly, it checks
# that such a test is able to fail. unittest.expectedFailure doesn't work with
# parameterized test methods, so the class is decorated instead.
@unittest.expectedFailure
@combinations.generate(
    combinations.combine(distribution=[
        combinations.NamedDistribution(
            "OneChiefOneWorker", lambda: None, has_chief=True, num_workers=1),
        combinations.NamedDistribution(
            "TwoWorkers", lambda: None, has_chief=False, num_workers=2),
    ]))
class CombinationsOnClassMultiWorkerExpectedFailureTest(test.TestCase,
                                                        parameterized.TestCase):
  """Class-level generate: multi-worker tests must run, and be able to fail."""

  def test(self, distribution):
    resolver = tfconfig_cluster_resolver.TFConfigClusterResolver()
    # This should fail when run in a worker (hence expectedFailure).
    self.assertIsNone(resolver.task_id)


class TfFunctionTest(test.TestCase, parameterized.TestCase):
  """Checks the tf_function / no_tf_function combination values."""

  @combinations.generate(
      combinations.combine(
          tf_function_1=combinations.tf_function,
          tf_function_2=combinations.no_tf_function,
          mode="eager",
      ))
  def testFunc(self, tf_function_1, tf_function_2):
    # `tf_function` makes the wrapped body execute non-eagerly, while
    # `no_tf_function` leaves it eager, even though the test mode is "eager".

    @tf_function_1
    def foo():
      self.assertFalse(context.executing_eagerly())

    @tf_function_2
    def bar():
      self.assertTrue(context.executing_eagerly())

    foo()
    bar()


class ModuleInitializingTest(test.TestCase, parameterized.TestCase):
  """Checks that the combinations module can be (re)imported safely."""

  def testSysArgvClearedIsFine(self):
    """Reloading `combinations` with an empty sys.argv must not raise."""
    argv = sys.argv
    saved_argv = list(argv)
    argv.clear()
    try:
      importlib.reload(combinations)
    finally:
      # Always restore, even if the reload raises, so an empty sys.argv does
      # not leak into later tests. Refill the same list object in place (in
      # case anything else holds a reference to it) and rebind sys.argv to it.
      argv[:] = saved_argv
      sys.argv = argv


class ShareGPUTest(test.TestCase, parameterized.TestCase):
  """Checks GPU visibility in workers with and without `share_gpu`."""

  def setUp(self):
    super().setUp()
    # Only the main process decides whether to skip; note that the attribute
    # is spelled `total_phsyical_gpus` (sic) on the combinations environment.
    if combinations.in_main_process():
      num_gpus = combinations.env().total_phsyical_gpus
      if num_gpus not in (2, 4):
        self.skipTest("requires 2 or 4 GPUs")

  # Test cases are annotated with required_gpus only for them to run in gpu
  # targets, otherwise they will be skipped.

  @combinations.generate(
      combinations.combine(num_workers=2, required_gpus=1, share_gpu=True))
  def testShareGPU(self):
    """With share_gpu=True every worker sees all physical GPUs."""
    self.assertLen(context.context().list_physical_devices("GPU"),
                   combinations.env().total_phsyical_gpus)

  @combinations.generate(combinations.combine(num_workers=2, required_gpus=1))
  def testShareGPUByDefault(self):
    """Without an explicit share_gpu, every worker sees all physical GPUs."""
    self.assertLen(context.context().list_physical_devices("GPU"),
                   combinations.env().total_phsyical_gpus)

  @combinations.generate(
      combinations.combine(num_workers=2, required_gpus=1, share_gpu=False))
  def testNotShareGPU(self):
    """With share_gpu=False the GPUs are split evenly between two workers."""
    # Integer division: a device count is an int, not a float.
    self.assertLen(context.context().list_physical_devices("GPU"),
                   combinations.env().total_phsyical_gpus // 2)


if __name__ == "__main__":
  test_util.main()