# Owner(s): ["module: dynamo"]
import sys
import unittest
from unittest.mock import MagicMock, patch

import torch
import torch._dynamo
import torch._dynamo.backends
import torch._dynamo.test_case
from torch._dynamo.backends.debugging import ExplainWithBackend
from torch._dynamo.backends.onnxrt import has_onnxruntime
from torch._dynamo.backends.tvm import has_tvm
from torch._dynamo.testing import same
from torch.fx._lazy_graph_module import _force_skip_lazy_graph_module
from torch.testing._internal.inductor_utils import HAS_CUDA


requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")


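# Small two-layer MLP used as the shared model in the backend smoke tests below.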
class Seq(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        return self.layers(x)


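# Conv -> BatchNorm -> ReLU block, a common fusion pattern for backends to optimize.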
class Conv_Bn_Relu(torch.nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


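# Smoke tests for the built-in backends: compile a small model with each one
# and check the output against eager mode.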
class TestOptimizations(torch._dynamo.test_case.TestCase):
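    # The custom compiler_fn runs the captured FX graph on the compile-time
    # example inputs; the assertions check that this output matches the real
    # eager (r2) and compiled (r3) outputs in shape, strides, and dtype.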
    def test_example_inputs(self):
        def fn(a, bc, d):
            b, c = bc
            return a / d - b / c

        def compiler_fn(graph, example_inputs):
            nonlocal r1
            r1 = graph(*example_inputs)[0]
            return graph.forward

        a = torch.empty(2).fill_(1)
        b = torch.empty(2).fill_(2)
        c = torch.empty(2).fill_(3)
        d = 4
        r1 = None
        r2 = fn(a, (b, c), d)
        opt_fn = torch._dynamo.optimize_assert(compiler_fn)(fn)
        r3 = opt_fn(a, (b, c), d)

        self.assertIsNotNone(r1)
        self.assertEqual(r1.size(), r2.size())
        self.assertEqual(r1.stride(), r2.stride())
        self.assertEqual(r1.dtype, r2.dtype)

        self.assertEqual(r1.size(), r3.size())
        self.assertEqual(r1.stride(), r3.stride())
        self.assertEqual(r1.dtype, r3.dtype)

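    # Same idea as above, but the backend returns a wrapper that records the
    # graph output at runtime rather than running the graph at compile time.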
    def test_example_inputs_runtime_use(self):
        def fn(a, bc, d):
            b, c = bc
            return a / d - b / c

        def compiler_fn(graph, example_inputs):
            def fwd(*args):
                nonlocal r1
                r = graph.forward(*args)
                r1 = r[0]
                return r

            return fwd

        a = torch.empty(2).fill_(1)
        b = torch.empty(2).fill_(2)
        c = torch.empty(2).fill_(3)
        d = 4
        r1 = None
        r2 = fn(a, (b, c), d)
        opt_fn = torch._dynamo.optimize_assert(compiler_fn)(fn)
        r3 = opt_fn(a, (b, c), d)

        self.assertIsNotNone(r1)
        self.assertTrue(same(r1, r2))
        self.assertTrue(same(r1, r3))

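    # Shared helper: compile `Seq` with the given backend/options and compare
    # against the eager result within a loose tolerance.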
    def _check_backend_works(self, backend, options=None):
        model = Seq().eval()
        input = torch.randn(2, 10)
        r1 = model(input)
        r2 = torch.compile(model, backend=backend, options=options)(input)
        self.assertTrue(same(r1, r2.float(), tol=0.01))

    def test_eager(self):
        self._check_backend_works("eager")

    def test_eager_noexcept(self):
        self._check_backend_works("eager_noexcept")

    @_force_skip_lazy_graph_module()
    def test_torchscript(self):
        self._check_backend_works("ts")

    def test_aot_eager(self):
        self._check_backend_works("aot_eager")

    def test_aot_eager_decomp_partition(self):
        self._check_backend_works("aot_eager_decomp_partition")

    @_force_skip_lazy_graph_module()
    def test_aot_ts(self):
        self._check_backend_works("aot_ts")

    @requires_cuda
    def test_aot_cudagraphs(self):
        self._check_backend_works("cudagraphs")

    @unittest.skipIf(not has_onnxruntime(), "requires onnxruntime")
    def test_onnxrt(self):
        self._check_backend_works("onnxrt")

    @unittest.skipIf(not has_tvm(), "requires tvm")
    def test_tvm(self):
        self._check_backend_works("tvm")
        self._check_backend_works("tvm", options={"scheduler": None})
        self._check_backend_works("tvm", options={"opt_level": 0})

    def test_list_backends(self):
        self.assertIn("inductor", torch._dynamo.list_backends())
        self.assertIn("inductor", torch._dynamo.list_backends(exclude_tags=None))
        self.assertNotIn("eager", torch._dynamo.list_backends())
        self.assertNotIn("eager", torch._dynamo.list_backends(exclude_tags=["debug"]))
        self.assertIn("eager", torch._dynamo.list_backends(exclude_tags=[]))


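# Exercises in-place mutation (`x += b`, with a dtype-mismatched operand)
# through the aot_eager backend, which must normalize it without changing
# the result.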
class NormalizeIRTests(torch._dynamo.test_case.TestCase):
    def test_inplace_normalize(self):
        def fn(a, b):
            x = torch.cos(a)
            x += b
            return torch.sin(x)

        a = torch.randn(10)
        b = torch.randn(10).to(torch.float64)

        ref = fn(a, b)

        optimized_fn = torch._dynamo.optimize("aot_eager")(fn)
        res = optimized_fn(a, b)
        self.assertTrue(same(ref, res))


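# Inductor does not support the MPS device at this revision, so compiling an
# MPS model should raise rather than silently fall back.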
class MPSNotSupportedTest(torch._dynamo.test_case.TestCase):
    @unittest.skipIf(not torch.backends.mps.is_available(), "requires mps")
    def test_mps_not_supported(self):
        model = Seq().to("mps")
        example_input = torch.randn(1, 10).to("mps")
        self.assertRaises(
            RuntimeError,
            lambda: torch.compile(model, backend="inductor")(example_input),
        )


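# ExplainWithBackend wraps a real backend and records graph-break statistics
# while still compiling each subgraph with the wrapped backend.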
class TestExplainWithBackend(torch._dynamo.test_case.TestCase):
    def test_explain_with_backend(self):
        def fn3(x):
            x = torch.sin(x)
            torch._dynamo.graph_break()
            x = torch.sin(x)
            return x

        def fn2(x):
            x = torch.cos(x)
            x = fn3(x)
            x = torch.cos(x)
            return x

        def fn1(x):
            x = torch.tan(x)
            x = fn2(x)
            x = torch.tan(x)
            return x

        def fn(x):
            x = torch.sigmoid(x)
            x = fn1(x)
            x = torch.sigmoid(x)
            return x

        # Wrap TorchInductor with the explain backend
        eb = ExplainWithBackend("inductor")
        optimized_fn = torch.compile(fn, backend=eb)
        input_tensor = torch.randn(5)
        result = optimized_fn(input_tensor)

        # Check that fn still produces the same output when wrapped by ExplainWithBackend
        self.assertTrue(torch.allclose(result, fn(input_tensor)))

        # Verify the ExplainOutput contents; exact values may change across
        # versions, but these fields should always be present
        explain_output = eb.output()
        explain_str = str(explain_output)
        self.assertIn("Graph Count", explain_str)
        self.assertIn("Graph Break Count", explain_str)
        self.assertIn("Op Count", explain_str)
        self.assertIn("Break Reasons", explain_str)

        # Verify the reported counts for the functions above: the graph break
        # in fn3 prevents inlining, so each of fn, fn1, fn2, and fn3 is split
        # into two single-op graphs (8 graphs, 8 ops, 7 graph breaks)
        self.assertEqual(8, explain_output.graph_count)
        self.assertEqual(7, explain_output.graph_break_count)
        self.assertEqual(8, explain_output.op_count)


class TestCustomBackendAPI(torch._dynamo.test_case.TestCase):
    """Test the APIs documented at https://pytorch.org/docs/main/torch.compiler_custom_backends.html"""

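    # `register_backend` makes the decorated function available to
    # torch.compile under its string name.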
    def test_register_backend_api(self):
        from torch._dynamo import register_backend

        backend_run = False

        @register_backend
        def my_custom_backend(gm, example_inputs):
            nonlocal backend_run
            backend_run = True
            return gm.forward

        def f(x):
            return torch.relu(x)

        opt_f = torch.compile(f, backend="my_custom_backend")
        opt_f(torch.randn(3, 3))
        self.assertTrue(backend_run)

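    # `aot_autograd` builds a backend from a forward-graph compiler; the
    # compiler must return a boxed callable, hence `make_boxed_func`.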
    def test_aot_autograd_api(self):
        from functorch.compile import make_boxed_func
        from torch._dynamo.backends.common import aot_autograd

        backend_run = False

        def my_compiler(gm, example_inputs):
            nonlocal backend_run
            backend_run = True
            return make_boxed_func(gm.forward)

        my_backend = aot_autograd(fw_compiler=my_compiler)

        def f(x):
            return torch.relu(x)

        opt_f = torch.compile(f, backend=my_backend)
        opt_f(torch.randn(3, 3))
        self.assertTrue(backend_run)

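    # `lookup_backend` resolves a backend by name, which lets a custom
    # compiler fall back through a chain of candidate backends.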
    def test_lookup_backend(self):
        from torch._dynamo import list_backends, lookup_backend

        backends = list_backends()
        backend_run = False

        def my_compiler(gm, example_inputs):
            nonlocal backend_run
            backend_run = True
            try:
                trt_compiled = lookup_backend("tensorrt")(gm, example_inputs)
                if trt_compiled is not None:
                    return trt_compiled
            except Exception:
                pass
            # first backend failed, try something else...
            try:
                inductor_compiled = lookup_backend("inductor")(gm, example_inputs)
                if inductor_compiled is not None:
                    return inductor_compiled
            except Exception:
                pass
            return gm.forward

        def f(x):
            return torch.relu(x)

        opt_f = torch.compile(f, backend=my_compiler)
        opt_f(torch.randn(3, 3))
        self.assertTrue(backend_run)

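    # Backends can also be discovered via the "torch_dynamo_backends"
    # entry-point group. importlib.metadata.entry_points changed shape in
    # Python 3.10 (it gained the `group=` keyword and returns an EntryPoints
    # object instead of a dict), so both forms are mocked here.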
    def test_lookup_custom_backend(self):
        from torch._dynamo import list_backends

        backends_group = "torch_dynamo_backends"
        name = "mycustombackend"

        mock_3_9 = MagicMock()
        mock_3_9.load.return_value = lambda: "mocked 3.9"
        mock_3_9.name = name

        mock_3_10 = MagicMock()
        mock_3_10.load.return_value = lambda: "mocked 3.10"

        def mock_eps(group=None):
            if sys.version_info < (3, 10):
                return {backends_group: [mock_3_9]}
            else:
                assert group == backends_group, group
                mock_group = MagicMock()
                mock_group.names = [name]
                mock_group[name] = mock_3_10
                return mock_group

        with patch("importlib.metadata.entry_points", mock_eps):
            from torch._dynamo.backends import registry

            registry._lazy_import.cache_clear()
            registry._discover_entrypoint_backends.cache_clear()

            backends = list_backends()
            assert name in backends, (name, backends)

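    # Changing backend options must trigger a fresh compile rather than reuse
    # the cached artifact: the second compile, with
    # `_raise_error_for_testing=True`, should fail inside the backend.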
    def test_backend_recompilation(self):
        def fn(x):
            return x + x

        input = torch.tensor(2.0)

        opt_fn = torch.compile(
            fn, backend="inductor", options={"_raise_error_for_testing": False}
        )
        opt_fn(input)
        with self.assertRaises(torch._dynamo.exc.BackendCompilerFailed):
            opt_fn = torch.compile(
                fn, backend="inductor", options={"_raise_error_for_testing": True}
            )
            opt_fn(input)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()