1# Owner(s): ["module: functorch"] 2import torch 3import torch._dynamo 4import torch._functorch 5import torch._inductor 6import torch._inductor.decomposition 7from torch._higher_order_ops.torchbind import enable_torchbind_tracing 8from torch._inductor.test_case import run_tests, TestCase 9from torch.testing._internal.torchbind_impls import init_torchbind_implementations 10 11 12class TestTorchbind(TestCase): 13 def setUp(self): 14 super().setUp() 15 init_torchbind_implementations() 16 17 def get_exported_model(self): 18 """ 19 Returns the ExportedProgram, example inputs, and result from calling the 20 eager model with those inputs 21 """ 22 23 class M(torch.nn.Module): 24 def __init__(self) -> None: 25 super().__init__() 26 self.attr = torch.classes._TorchScriptTesting._Foo(10, 20) 27 self.b = torch.randn(2, 3) 28 29 def forward(self, x): 30 x = x + self.b 31 a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x) 32 y = a[0] + a[1] 33 b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y) 34 return x + b 35 36 m = M() 37 inputs = (torch.ones(2, 3),) 38 orig_res = m(*inputs) 39 40 # We can't directly torch.compile because dynamo doesn't trace ScriptObjects yet 41 with enable_torchbind_tracing(): 42 ep = torch.export.export(m, inputs, strict=False) 43 44 return ep, inputs, orig_res 45 46 def test_torchbind_inductor(self): 47 ep, inputs, orig_res = self.get_exported_model() 48 compiled = torch._inductor.compile(ep.module(), inputs) 49 50 new_res = compiled(*inputs) 51 self.assertTrue(torch.allclose(orig_res, new_res)) 52 53 54if __name__ == "__main__": 55 run_tests() 56