xref: /aosp_15_r20/external/pytorch/test/export/test_tools.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: export"]
2
3import torch
4from torch._dynamo.test_case import TestCase
5from torch._export.tools import report_exportability
6from torch.testing._internal.common_utils import run_tests
7
8
9torch.library.define(
10    "testlib::op_missing_meta",
11    "(Tensor(a!) x, Tensor(b!) z) -> Tensor",
12    tags=torch.Tag.pt2_compliant_tag,
13)
14
15
16@torch.library.impl("testlib::op_missing_meta", "cpu")
17@torch._dynamo.disable
18def op_missing_meta(x, z):
19    x.add_(5)
20    z.add_(5)
21    return x + z
22
23
class TestExportTools(TestCase):
    """Tests for ``torch._export.tools.report_exportability``.

    The assertions below index the returned report by submodule name, with
    the empty string used for the top-level module. An entry of ``None`` is
    treated as "no issue found"; any other value is treated as an issue.
    """

    def test_report_exportability_basic(self):
        """A simple module yields exactly one, issue-free report entry."""

        class Module(torch.nn.Module):
            def forward(self, x, y):
                return x[0] + y

        f = Module()
        # The first positional input is a list holding one tensor, which is
        # what ``x[0]`` in ``forward`` indexes into.
        inp = ([torch.ones(1, 3)], torch.ones(1, 3))

        report = report_exportability(f, inp)
        # ``Module`` has no child modules, so only the root entry is expected.
        self.assertEqual(len(report), 1)
        self.assertIsNone(report[""])

    def test_report_exportability_with_issues(self):
        """Issues are reported per submodule, not just for the root."""

        class Unsupported(torch.nn.Module):
            def forward(self, x):
                # Calls the custom ``testlib`` op registered at module level
                # in this file; the test expects this submodule to be flagged.
                return torch.ops.testlib.op_missing_meta(x, x.cos())

        class Supported(torch.nn.Module):
            def forward(self, x):
                return x.sin()

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.unsupported = Unsupported()
                self.supported = Supported()

            def forward(self, x):
                y = torch.nonzero(x)
                return self.unsupported(y) + self.supported(y)

        f = Module()
        inp = (torch.ones(4, 4),)

        report = report_exportability(f, inp, strict=False, pre_dispatch=True)

        # The root and the ``unsupported`` child must each carry an issue,
        # while the ``supported`` child must be reported as clean.
        self.assertIsNotNone(report[""])
        self.assertIsNotNone(report["unsupported"])
        self.assertIsNone(report["supported"])
64
65
# Run the tests through ``run_tests`` (imported from
# ``torch.testing._internal.common_utils``) when this file is executed directly.
if __name__ == "__main__":
    run_tests()
68