1# Owner(s): ["oncall: export"] 2 3import copy 4import unittest 5 6import torch._dynamo as torchdynamo 7from torch._export.db.case import ExportCase, SupportLevel 8from torch._export.db.examples import ( 9 filter_examples_by_support_level, 10 get_rewrite_cases, 11) 12from torch.export import export 13from torch.testing._internal.common_utils import ( 14 instantiate_parametrized_tests, 15 IS_WINDOWS, 16 parametrize, 17 run_tests, 18 TestCase, 19) 20 21 22@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test") 23@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support") 24class ExampleTests(TestCase): 25 # TODO Maybe we should make this tests actually show up in a file? 26 @parametrize( 27 "name,case", 28 filter_examples_by_support_level(SupportLevel.SUPPORTED).items(), 29 name_fn=lambda name, case: f"case_{name}", 30 ) 31 def test_exportdb_supported(self, name: str, case: ExportCase) -> None: 32 model = case.model 33 34 args_export = case.example_args 35 kwargs_export = case.example_kwargs 36 args_model = copy.deepcopy(args_export) 37 kwargs_model = copy.deepcopy(kwargs_export) 38 exported_program = export( 39 model, 40 args_export, 41 kwargs_export, 42 dynamic_shapes=case.dynamic_shapes, 43 ) 44 exported_program.graph_module.print_readable() 45 46 self.assertEqual( 47 exported_program.module()(*args_export, **kwargs_export), 48 model(*args_model, **kwargs_model), 49 ) 50 51 if case.extra_args is not None: 52 args = case.extra_args 53 args_model = copy.deepcopy(args) 54 self.assertEqual( 55 exported_program.module()(*args), 56 model(*args_model), 57 ) 58 59 @parametrize( 60 "name,case", 61 filter_examples_by_support_level(SupportLevel.NOT_SUPPORTED_YET).items(), 62 name_fn=lambda name, case: f"case_{name}", 63 ) 64 def test_exportdb_not_supported(self, name: str, case: ExportCase) -> None: 65 model = case.model 66 # pyre-ignore 67 with self.assertRaises( 68 (torchdynamo.exc.Unsupported, AssertionError, RuntimeError) 69 ): 70 export( 71 model, 72 case.example_args, 73 case.example_kwargs, 74 dynamic_shapes=case.dynamic_shapes, 75 ) 76 77 exportdb_not_supported_rewrite_cases = [ 78 (name, rewrite_case) 79 for name, case in filter_examples_by_support_level( 80 SupportLevel.NOT_SUPPORTED_YET 81 ).items() 82 for rewrite_case in get_rewrite_cases(case) 83 ] 84 if exportdb_not_supported_rewrite_cases: 85 86 @parametrize( 87 "name,rewrite_case", 88 exportdb_not_supported_rewrite_cases, 89 name_fn=lambda name, case: f"case_{name}_{case.name}", 90 ) 91 def test_exportdb_not_supported_rewrite( 92 self, name: str, rewrite_case: ExportCase 93 ) -> None: 94 # pyre-ignore 95 export( 96 rewrite_case.model, 97 rewrite_case.example_args, 98 rewrite_case.example_kwargs, 99 dynamic_shapes=rewrite_case.dynamic_shapes, 100 ) 101 102 103instantiate_parametrized_tests(ExampleTests) 104 105 106if __name__ == "__main__": 107 run_tests() 108