1# Owner(s): ["module: dynamo"] 2import sys 3import unittest 4from unittest.mock import MagicMock, patch 5 6import torch 7import torch._dynamo 8import torch._dynamo.backends 9import torch._dynamo.test_case 10from torch._dynamo.backends.debugging import ExplainWithBackend 11from torch._dynamo.backends.onnxrt import has_onnxruntime 12from torch._dynamo.backends.tvm import has_tvm 13from torch._dynamo.testing import same 14from torch.fx._lazy_graph_module import _force_skip_lazy_graph_module 15from torch.testing._internal.inductor_utils import HAS_CUDA 16 17 18requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") 19 20 21class Seq(torch.nn.Module): 22 def __init__(self) -> None: 23 super().__init__() 24 self.layers = torch.nn.Sequential( 25 torch.nn.Linear(10, 10), 26 torch.nn.ReLU(), 27 torch.nn.Linear(10, 10), 28 torch.nn.Sigmoid(), 29 ) 30 31 def forward(self, x): 32 return self.layers(x) 33 34 35class Conv_Bn_Relu(torch.nn.Module): 36 def __init__(self, in_channels, out_channels, **kwargs): 37 super().__init__() 38 self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 39 self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001) 40 self.relu = torch.nn.ReLU() 41 42 def forward(self, x): 43 return self.relu(self.bn(self.conv(x))) 44 45 46class TestOptimizations(torch._dynamo.test_case.TestCase): 47 def test_example_inputs(self): 48 def fn(a, bc, d): 49 b, c = bc 50 return a / d - b / c 51 52 def compiler_fn(graph, example_inputs): 53 nonlocal r1 54 r1 = graph(*example_inputs)[0] 55 return graph.forward 56 57 a = torch.empty(2).fill_(1) 58 b = torch.empty(2).fill_(2) 59 c = torch.empty(2).fill_(3) 60 d = 4 61 r1 = None 62 r2 = fn(a, (b, c), d) 63 opt_fn = torch._dynamo.optimize_assert(compiler_fn)(fn) 64 r3 = opt_fn(a, (b, c), d) 65 66 self.assertIsNotNone(r1) 67 self.assertEqual(r1.size(), r2.size()) 68 self.assertEqual(r1.stride(), r2.stride()) 69 self.assertEqual(r1.dtype, r2.dtype) 70 71 self.assertEqual(r1.size(), r3.size()) 72 self.assertEqual(r1.stride(), r3.stride()) 73 self.assertEqual(r1.dtype, r3.dtype) 74 75 def test_example_inputs_runtime_use(self): 76 def fn(a, bc, d): 77 b, c = bc 78 return a / d - b / c 79 80 def compiler_fn(graph, example_inputs): 81 def fwd(*args): 82 nonlocal r1 83 r = graph.forward(*args) 84 r1 = r[0] 85 return r 86 87 return fwd 88 89 a = torch.empty(2).fill_(1) 90 b = torch.empty(2).fill_(2) 91 c = torch.empty(2).fill_(3) 92 d = 4 93 r1 = None 94 r2 = fn(a, (b, c), d) 95 opt_fn = torch._dynamo.optimize_assert(compiler_fn)(fn) 96 r3 = opt_fn(a, (b, c), d) 97 98 self.assertIsNotNone(r1) 99 self.assertTrue(same(r1, r2)) 100 self.assertTrue(same(r1, r3)) 101 102 def _check_backend_works(self, backend, options=None): 103 model = Seq().eval() 104 input = torch.randn(2, 10) 105 r1 = model(input) 106 r2 = torch.compile(model, backend=backend, options=options)(input) 107 self.assertTrue(same(r1, r2.float(), tol=0.01)) 108 109 def test_eager(self): 110 self._check_backend_works("eager") 111 112 def test_eager_noexcept(self): 113 self._check_backend_works("eager_noexcept") 114 115 @_force_skip_lazy_graph_module() 116 def test_torchscript(self): 117 self._check_backend_works("ts") 118 119 def test_aot_eager(self): 120 self._check_backend_works("aot_eager") 121 122 def test_aot_eager_decomp_partition(self): 123 self._check_backend_works("aot_eager_decomp_partition") 124 125 @_force_skip_lazy_graph_module() 126 def test_aot_ts(self): 127 self._check_backend_works("aot_ts") 128 129 @requires_cuda 130 def test_aot_cudagraphs(self): 131 self._check_backend_works("cudagraphs") 132 133 @unittest.skipIf(not has_onnxruntime(), "requires onnxruntime") 134 def test_onnxrt(self): 135 self._check_backend_works("onnxrt") 136 137 @unittest.skipIf(not has_tvm(), "requires tvm") 138 def test_tvm(self): 139 self._check_backend_works("tvm") 140 self._check_backend_works("tvm", options={"scheduler": None}) 141 self._check_backend_works("tvm", options={"opt_level": 0}) 142 143 def test_list_backends(self): 144 self.assertIn("inductor", torch._dynamo.list_backends()) 145 self.assertIn("inductor", torch._dynamo.list_backends(exclude_tags=None)) 146 self.assertNotIn("eager", torch._dynamo.list_backends()) 147 self.assertNotIn("eager", torch._dynamo.list_backends(exclude_tags=["debug"])) 148 self.assertIn("eager", torch._dynamo.list_backends(exclude_tags=[])) 149 150 151class NormalizeIRTests(torch._dynamo.test_case.TestCase): 152 def test_inplace_normalize(self): 153 def fn(a, b): 154 x = torch.cos(a) 155 x += b 156 return torch.sin(x) 157 158 a = torch.randn(10) 159 b = torch.randn(10).to(torch.float64) 160 161 ref = fn(a, b) 162 163 optimized_fn = torch._dynamo.optimize("aot_eager")(fn) 164 res = optimized_fn(a, b) 165 self.assertTrue(same(ref, res)) 166 167 168class MPSNotSupportedTest(torch._dynamo.test_case.TestCase): 169 @unittest.skipIf(not torch.backends.mps.is_available(), "requires mps") 170 def test_mps_not_supported(self): 171 model = Seq().to("mps") 172 example_input = torch.randn(1, 10).to("mps") 173 self.assertRaises( 174 RuntimeError, 175 lambda: torch.compile(model, backend="inductor")(example_input), 176 ) 177 178 179class TestExplainWithBackend(torch._dynamo.test_case.TestCase): 180 def test_explain_with_backend(self): 181 def fn3(x): 182 x = torch.sin(x) 183 torch._dynamo.graph_break() 184 x = torch.sin(x) 185 return x 186 187 def fn2(x): 188 x = torch.cos(x) 189 x = fn3(x) 190 x = torch.cos(x) 191 return x 192 193 def fn1(x): 194 x = torch.tan(x) 195 x = fn2(x) 196 x = torch.tan(x) 197 return x 198 199 def fn(x): 200 x = torch.sigmoid(x) 201 x = fn1(x) 202 x = torch.sigmoid(x) 203 return x 204 205 # Wrap TorchInductor with explain backend 206 eb = ExplainWithBackend("inductor") 207 optimized_fn = torch.compile(fn, backend=eb) 208 input_tensor = torch.randn(5) 209 result = optimized_fn(input_tensor) 210 211 # Check that fn still produces the same output when wrapped by ExplainWithBackend 212 self.assertTrue(torch.allclose(result, fn(input_tensor))) 213 214 # Verify ExplainOutput object contents, output might change but make sure these fields are present 215 explain_output = eb.output() 216 explain_str = str(explain_output) 217 self.assertIn("Graph Count", explain_str) 218 self.assertIn("Graph Break Count", explain_str) 219 self.assertIn("Op Count", explain_str) 220 self.assertIn("Break Reasons", explain_str) 221 222 # Verify that for the given functions above, we report the correct number of graphs, graph breaks, and ops 223 self.assertEqual(8, explain_output.graph_count) 224 self.assertEqual(7, explain_output.graph_break_count) 225 self.assertEqual(8, explain_output.op_count) 226 227 228class TestCustomBackendAPI(torch._dynamo.test_case.TestCase): 229 """Test APIs documented by https://pytorch.org/docs/main/torch.compiler_custom_backends.html""" 230 231 def test_register_backend_api(self): 232 from torch._dynamo import register_backend 233 234 backend_run = False 235 236 @register_backend 237 def my_custom_backend(gm, example_inputs): 238 nonlocal backend_run 239 backend_run = True 240 return gm.forward 241 242 def f(x): 243 return torch.relu(x) 244 245 opt_f = torch.compile(f, backend="my_custom_backend") 246 opt_f(torch.randn(3, 3)) 247 self.assertTrue(backend_run) 248 249 def test_aot_autograd_api(self): 250 from functorch.compile import make_boxed_func 251 from torch._dynamo.backends.common import aot_autograd 252 253 backend_run = False 254 255 def my_compiler(gm, example_inputs): 256 nonlocal backend_run 257 backend_run = True 258 return make_boxed_func(gm.forward) 259 260 my_backend = aot_autograd(fw_compiler=my_compiler) 261 262 def f(x): 263 return torch.relu(x) 264 265 opt_f = torch.compile(f, backend=my_backend) 266 opt_f(torch.randn(3, 3)) 267 self.assertTrue(backend_run) 268 269 def test_lookup_backend(self): 270 from torch._dynamo import list_backends, lookup_backend 271 272 backends = list_backends() 273 backend_run = False 274 275 def my_compiler(gm, example_inputs): 276 nonlocal backend_run 277 backend_run = True 278 try: 279 trt_compiled = lookup_backend("tensorrt")(gm, example_inputs) 280 if trt_compiled is not None: 281 return trt_compiled 282 except Exception: 283 pass 284 # first backend failed, try something else... 285 try: 286 inductor_compiled = lookup_backend("inductor")(gm, example_inputs) 287 if inductor_compiled is not None: 288 return inductor_compiled 289 except Exception: 290 pass 291 return gm.forward 292 293 def f(x): 294 return torch.relu(x) 295 296 opt_f = torch.compile(f, backend=my_compiler) 297 opt_f(torch.randn(3, 3)) 298 self.assertTrue(backend_run) 299 300 def test_lookup_custom_backend(self): 301 from torch._dynamo import list_backends 302 303 backends_group = "torch_dynamo_backends" 304 name = "mycustombackend" 305 306 mock_3_9 = MagicMock() 307 mock_3_9.load.return_value = lambda: "mocked 3.9" 308 mock_3_9.name = name 309 310 mock_3_10 = MagicMock() 311 mock_3_10.load.return_value = lambda: "mocked 3.10" 312 313 def mock_eps(group=None): 314 if sys.version_info < (3, 10): 315 return {backends_group: [mock_3_9]} 316 else: 317 assert group == backends_group, group 318 mock_group = MagicMock() 319 mock_group.names = [name] 320 mock_group[name] = mock_3_10 321 # mock_group[name].load.return_value = lambda: "mocked 3.10" 322 return mock_group 323 324 with patch("importlib.metadata.entry_points", mock_eps): 325 from torch._dynamo.backends import registry 326 327 registry._lazy_import.cache_clear() 328 registry._discover_entrypoint_backends.cache_clear() 329 330 backends = list_backends() 331 assert name in backends, (name, backends) 332 333 def test_backend_recompilation(self): 334 def fn(x): 335 return x + x 336 337 input = torch.tensor(2.0) 338 339 opt_fn = torch.compile( 340 fn, backend="inductor", options={"_raise_error_for_testing": False} 341 ) 342 opt_fn(input) 343 with self.assertRaises(torch._dynamo.exc.BackendCompilerFailed): 344 opt_fn = torch.compile( 345 fn, backend="inductor", options={"_raise_error_for_testing": True} 346 ) 347 opt_fn(input) 348 349 350if __name__ == "__main__": 351 from torch._dynamo.test_case import run_tests 352 353 run_tests() 354