xref: /aosp_15_r20/external/pytorch/test/export/test_db.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: export"]
2
3import copy
4import unittest
5
6import torch._dynamo as torchdynamo
7from torch._export.db.case import ExportCase, SupportLevel
8from torch._export.db.examples import (
9    filter_examples_by_support_level,
10    get_rewrite_cases,
11)
12from torch.export import export
13from torch.testing._internal.common_utils import (
14    instantiate_parametrized_tests,
15    IS_WINDOWS,
16    parametrize,
17    run_tests,
18    TestCase,
19)
20
21
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class ExampleTests(TestCase):
    """Run the ExportDB example cases through ``torch.export.export``.

    Each test method is parametrized over the example cases returned by
    ``filter_examples_by_support_level`` for one ``SupportLevel``.
    """

    # TODO Maybe we should make these tests actually show up in a file?
    @parametrize(
        "name,case",
        filter_examples_by_support_level(SupportLevel.SUPPORTED).items(),
        name_fn=lambda name, case: f"case_{name}",
    )
    def test_exportdb_supported(self, name: str, case: ExportCase) -> None:
        """Export a supported case and compare it against eager execution."""
        model = case.model

        # Give the exported module and the eager model independent deep
        # copies of the inputs. Passing ``case.example_args`` directly (as
        # opposed to a copy) would let an example that mutates its inputs in
        # place modify the shared ``case`` object itself.
        args_export = copy.deepcopy(case.example_args)
        kwargs_export = copy.deepcopy(case.example_kwargs)
        args_model = copy.deepcopy(case.example_args)
        kwargs_model = copy.deepcopy(case.example_kwargs)
        exported_program = export(
            model,
            args_export,
            kwargs_export,
            dynamic_shapes=case.dynamic_shapes,
        )
        # Smoke-check that the exported graph can be rendered as readable code.
        exported_program.graph_module.print_readable()

        self.assertEqual(
            exported_program.module()(*args_export, **kwargs_export),
            model(*args_model, **kwargs_model),
        )

        # Some cases provide ``extra_args``: a second set of positional inputs
        # that is run through the same exported program and compared against
        # eager execution. Again, each side gets its own copy.
        if case.extra_args is not None:
            args_export = copy.deepcopy(case.extra_args)
            args_model = copy.deepcopy(case.extra_args)
            self.assertEqual(
                exported_program.module()(*args_export),
                model(*args_model),
            )

    @parametrize(
        "name,case",
        filter_examples_by_support_level(SupportLevel.NOT_SUPPORTED_YET).items(),
        name_fn=lambda name, case: f"case_{name}",
    )
    def test_exportdb_not_supported(self, name: str, case: ExportCase) -> None:
        """Check that exporting a not-yet-supported case raises an error.

        Any of ``torchdynamo.exc.Unsupported``, ``AssertionError`` or
        ``RuntimeError`` is accepted as the expected failure.
        """
        model = case.model
        # pyre-ignore
        with self.assertRaises(
            (torchdynamo.exc.Unsupported, AssertionError, RuntimeError)
        ):
            export(
                model,
                case.example_args,
                case.example_kwargs,
                dynamic_shapes=case.dynamic_shapes,
            )

    # (case name, rewrite case) pairs: for every not-yet-supported example,
    # each rewrite returned by ``get_rewrite_cases``. Computed once, when the
    # class body is executed.
    exportdb_not_supported_rewrite_cases = [
        (name, rewrite_case)
        for name, case in filter_examples_by_support_level(
            SupportLevel.NOT_SUPPORTED_YET
        ).items()
        for rewrite_case in get_rewrite_cases(case)
    ]
    # Only define the parametrized test when there is at least one rewrite.
    if exportdb_not_supported_rewrite_cases:

        @parametrize(
            "name,rewrite_case",
            exportdb_not_supported_rewrite_cases,
            name_fn=lambda name, rewrite_case: f"case_{name}_{rewrite_case.name}",
        )
        def test_exportdb_not_supported_rewrite(
            self, name: str, rewrite_case: ExportCase
        ) -> None:
            """Check that the rewritten form of an unsupported case exports.

            Only success of ``export`` is checked; outputs are not compared.
            """
            # pyre-ignore
            export(
                rewrite_case.model,
                rewrite_case.example_args,
                rewrite_case.example_kwargs,
                dynamic_shapes=rewrite_case.dynamic_shapes,
            )
101
102
# Expand the @parametrize-decorated methods of ExampleTests into concrete
# test methods.
instantiate_parametrized_tests(ExampleTests)


# Allow running this test file directly as a script.
if __name__ == "__main__":
    run_tests()
108