# xref: /aosp_15_r20/external/pytorch/test/inductor/test_torchbind.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: functorch"]
2import torch
3import torch._dynamo
4import torch._functorch
5import torch._inductor
6import torch._inductor.decomposition
7from torch._higher_order_ops.torchbind import enable_torchbind_tracing
8from torch._inductor.test_case import run_tests, TestCase
9from torch.testing._internal.torchbind_impls import init_torchbind_implementations
10
11
class TestTorchbind(TestCase):
    """Checks that Inductor can compile an exported program that uses torchbind objects."""

    def setUp(self):
        super().setUp()
        # Makes the _TorchScriptTesting torchbind classes/ops used below
        # (torch.classes._TorchScriptTesting / torch.ops._TorchScriptTesting)
        # available — presumably by registering them; see
        # torch.testing._internal.torchbind_impls.
        init_torchbind_implementations()

    def get_exported_model(self):
        """
        Build a small module that uses a torchbind object and export it.

        Returns:
            A tuple ``(ep, inputs, orig_res)``: the ExportedProgram, the example
            inputs it was exported with, and the result of calling the eager
            module on those inputs. ``orig_res`` is computed before export, so
            it serves as the reference output.
        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # Torchbind (ScriptObject) attribute passed to the custom ops in forward.
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
                # Plain tensor attribute (not registered as a parameter or buffer).
                self.b = torch.randn(2, 3)

            def forward(self, x):
                x = x + self.b
                a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x)
                y = a[0] + a[1]
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
                return x + b

        m = M()
        inputs = (torch.ones(2, 3),)
        orig_res = m(*inputs)

        # We can't directly torch.compile because dynamo doesn't trace ScriptObjects yet
        with enable_torchbind_tracing():
            ep = torch.export.export(m, inputs, strict=False)

        return ep, inputs, orig_res

    def test_torchbind_inductor(self):
        """The Inductor-compiled exported module must reproduce the eager result."""
        ep, inputs, orig_res = self.get_exported_model()
        compiled = torch._inductor.compile(ep.module(), inputs)

        new_res = compiled(*inputs)
        # assertEqual, unlike assertTrue(torch.allclose(...)), also checks
        # shape and dtype (allclose broadcasts its operands). On failure it
        # reports the mismatching values instead of just "False is not true".
        self.assertEqual(orig_res, new_res)
52
53
if __name__ == "__main__":
    # Run this file's tests via the Inductor test harness when executed directly.
    run_tests()
56