1# Owner(s): ["oncall: export"] 2 3import torch 4from torch._dynamo.test_case import TestCase 5from torch._export.tools import report_exportability 6from torch.testing._internal.common_utils import run_tests 7 8 9torch.library.define( 10 "testlib::op_missing_meta", 11 "(Tensor(a!) x, Tensor(b!) z) -> Tensor", 12 tags=torch.Tag.pt2_compliant_tag, 13) 14 15 16@torch.library.impl("testlib::op_missing_meta", "cpu") 17@torch._dynamo.disable 18def op_missing_meta(x, z): 19 x.add_(5) 20 z.add_(5) 21 return x + z 22 23 24class TestExportTools(TestCase): 25 def test_report_exportability_basic(self): 26 class Module(torch.nn.Module): 27 def forward(self, x, y): 28 return x[0] + y 29 30 f = Module() 31 inp = ([torch.ones(1, 3)], torch.ones(1, 3)) 32 33 report = report_exportability(f, inp) 34 self.assertTrue(len(report) == 1) 35 self.assertTrue(report[""] is None) 36 37 def test_report_exportability_with_issues(self): 38 class Unsupported(torch.nn.Module): 39 def forward(self, x): 40 return torch.ops.testlib.op_missing_meta(x, x.cos()) 41 42 class Supported(torch.nn.Module): 43 def forward(self, x): 44 return x.sin() 45 46 class Module(torch.nn.Module): 47 def __init__(self) -> None: 48 super().__init__() 49 self.unsupported = Unsupported() 50 self.supported = Supported() 51 52 def forward(self, x): 53 y = torch.nonzero(x) 54 return self.unsupported(y) + self.supported(y) 55 56 f = Module() 57 inp = (torch.ones(4, 4),) 58 59 report = report_exportability(f, inp, strict=False, pre_dispatch=True) 60 61 self.assertTrue(report[""] is not None) 62 self.assertTrue(report["unsupported"] is not None) 63 self.assertTrue(report["supported"] is None) 64 65 66if __name__ == "__main__": 67 run_tests() 68