1# Owner(s): ["module: inductor"] 2import math 3import unittest 4 5import torch 6from torch._inductor import config 7from torch._inductor.test_case import run_tests, TestCase 8from torch.testing._internal.inductor_utils import HAS_CPU 9 10 11def dummy_fn(x): 12 return torch.sigmoid(x + math.pi) / 10.0 13 14 15class DummyModule(torch.nn.Module): 16 def forward(self, x): 17 return dummy_fn(x) 18 19 20class TestInductorConfig(TestCase): 21 @classmethod 22 def setUpClass(cls): 23 super().setUpClass() 24 cls._saved_config = config.save_config() 25 26 def tearDown(self): 27 super().tearDown() 28 config.load_config(self._saved_config) 29 30 def test_set(self): 31 config.max_fusion_size = 13337 32 self.assertEqual(config.max_fusion_size, 13337) 33 self.assertEqual(config.shallow_copy_dict()["max_fusion_size"], 13337) 34 config.max_fusion_size = 32 35 self.assertEqual(config.max_fusion_size, 32) 36 37 # a nested config 38 prior = config.triton.cudagraphs 39 config.triton.cudagraphs = not prior 40 self.assertEqual(config.triton.cudagraphs, not prior) 41 self.assertEqual(config.shallow_copy_dict()["triton.cudagraphs"], not prior) 42 43 def test_save_load(self): 44 config.max_fusion_size = 123 45 config.triton.cudagraphs = True 46 saved1 = config.save_config() 47 config.max_fusion_size = 321 48 config.triton.cudagraphs = False 49 saved2 = config.save_config() 50 51 self.assertEqual(config.max_fusion_size, 321) 52 self.assertEqual(config.triton.cudagraphs, False) 53 config.load_config(saved1) 54 self.assertEqual(config.max_fusion_size, 123) 55 self.assertEqual(config.triton.cudagraphs, True) 56 config.load_config(saved2) 57 self.assertEqual(config.max_fusion_size, 321) 58 self.assertEqual(config.triton.cudagraphs, False) 59 60 def test_hasattr(self): 61 self.assertTrue(hasattr(config, "max_fusion_size")) 62 self.assertFalse(hasattr(config, "missing_name")) 63 64 def test_invalid_names(self): 65 self.assertRaises(AttributeError, lambda: config.does_not_exist) 66 self.assertRaises(AttributeError, lambda: config.triton.does_not_exist) 67 68 def store1(): 69 config.does_not_exist = True 70 71 def store2(): 72 config.triton.does_not_exist = True 73 74 self.assertRaises(AttributeError, store1) 75 self.assertRaises(AttributeError, store2) 76 77 def test_patch(self): 78 with config.patch(max_fusion_size=456): 79 self.assertEqual(config.max_fusion_size, 456) 80 with config.patch(max_fusion_size=789): 81 self.assertEqual(config.max_fusion_size, 789) 82 self.assertEqual(config.max_fusion_size, 456) 83 84 with config.patch({"cpp.threads": 9000, "max_fusion_size": 9001}): 85 self.assertEqual(config.cpp.threads, 9000) 86 self.assertEqual(config.max_fusion_size, 9001) 87 with config.patch("cpp.threads", 8999): 88 self.assertEqual(config.cpp.threads, 8999) 89 self.assertEqual(config.cpp.threads, 9000) 90 91 @unittest.skipIf(not HAS_CPU, "requires C++ compiler") 92 def test_compile_api(self): 93 # these are mostly checking config processing doesn't blow up with exceptions 94 x = torch.randn(8) 95 y = dummy_fn(x) 96 checks = [ 97 {}, 98 {"mode": "default"}, 99 {"mode": "reduce-overhead"}, 100 {"mode": "max-autotune"}, 101 { 102 "options": { 103 "max-fusion-size": 128, 104 "unroll_reductions_threshold": 32, 105 "triton.cudagraphs": False, 106 } 107 }, 108 {"dynamic": True}, 109 {"fullgraph": True, "backend": "inductor"}, 110 {"disable": True}, 111 ] 112 113 for kwargs in checks: 114 torch._dynamo.reset() 115 opt_fn = torch.compile(dummy_fn, **kwargs) 116 torch.testing.assert_allclose( 117 opt_fn(x), y, msg=f"torch.compile(..., **{kwargs!r}) failed" 118 ) 119 120 def test_get_compiler_config(self): 121 from torch._inductor import config as inductor_default_config 122 123 default_cudagraphs = inductor_default_config._default["triton.cudagraphs"] 124 125 # nn.Module: should update default config with a new value 126 model = DummyModule() 127 optimized_module = torch.compile( 128 model, options={"triton.cudagraphs": not default_cudagraphs} 129 ) 130 compiler_config = optimized_module.get_compiler_config() 131 self.assertEqual(compiler_config["triton.cudagraphs"], not default_cudagraphs) 132 133 # nn.Module: keep default config 134 model = DummyModule() 135 optimized_module = torch.compile(model) 136 compiler_config = optimized_module.get_compiler_config() 137 self.assertEqual( 138 compiler_config["triton.cudagraphs"], 139 default_cudagraphs, 140 ) 141 142 # compile user func: should update default config with a new value 143 optimized_module = torch.compile( 144 dummy_fn, options={"triton.cudagraphs": not default_cudagraphs} 145 ) 146 compiler_config = optimized_module.get_compiler_config() 147 self.assertEqual(compiler_config["triton.cudagraphs"], not default_cudagraphs) 148 149 # compile user func: keep default config 150 optimized_module = torch.compile(dummy_fn) 151 compiler_config = optimized_module.get_compiler_config() 152 self.assertEqual( 153 compiler_config["triton.cudagraphs"], 154 default_cudagraphs, 155 ) 156 157 # backend=eager: expect None 158 optimized_module = torch.compile(dummy_fn, backend="eager") 159 compiler_config = optimized_module.get_compiler_config() 160 self.assertTrue(compiler_config is None) 161 162 def test_compile_api_passes_config(self): 163 # ensure configs are actually passed down to inductor 164 self.assertRaises( 165 torch._dynamo.exc.BackendCompilerFailed, 166 lambda: torch.compile(dummy_fn, options={"_raise_error_for_testing": True})( 167 torch.randn(10) 168 ), 169 ) 170 171 def test_api_options(self): 172 reduce_overhead_opts = torch._inductor.list_mode_options("reduce-overhead") 173 self.assertEqual(reduce_overhead_opts["triton.cudagraphs"], True) 174 self.assertEqual(reduce_overhead_opts.get("max_autotune", False), False) 175 176 max_autotune_opts = torch._inductor.list_mode_options("max-autotune") 177 self.assertEqual(max_autotune_opts["max_autotune"], True) 178 self.assertEqual(max_autotune_opts["triton.cudagraphs"], True) 179 180 max_autotune_opts = torch._inductor.list_mode_options( 181 "max-autotune", dynamic=True 182 ) 183 self.assertEqual(max_autotune_opts["max_autotune"], True) 184 self.assertEqual(max_autotune_opts["triton.cudagraphs"], True) 185 186 max_autotune_no_cudagraphs_opts = torch._inductor.list_mode_options( 187 "max-autotune-no-cudagraphs" 188 ) 189 self.assertEqual(max_autotune_no_cudagraphs_opts["max_autotune"], True) 190 self.assertEqual( 191 max_autotune_no_cudagraphs_opts.get("triton.cudagraphs", False), False 192 ) 193 194 def test_invalid_backend(self): 195 self.assertRaises( 196 torch._dynamo.exc.InvalidBackend, 197 lambda: torch.compile(dummy_fn, backend="does_not_exist")(torch.randn(10)), 198 ) 199 200 def test_non_inductor_backend(self): 201 def assert_options(expected_mode=None, expected_options=None): 202 def backend(gm, _, *, mode=None, options=None): 203 nonlocal call_count 204 self.assertEqual(mode, expected_mode) 205 self.assertEqual(options, expected_options) 206 call_count += 1 207 return gm 208 209 return backend 210 211 inp = torch.randn(8) 212 213 def fn(x): 214 return x + 1 215 216 for mode, options in [ 217 (None, None), 218 ("fast-mode", None), 219 (None, {"foo": "bar"}), 220 ]: 221 call_count = 0 222 torch.compile( 223 fn, backend=assert_options(mode, options), mode=mode, options=options 224 )(inp) 225 torch._dynamo.reset() 226 self.assertEqual(call_count, 1) 227 228 229if __name__ == "__main__": 230 run_tests() 231