xref: /aosp_15_r20/external/pytorch/test/package/generate_bc_packages.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from pathlib import Path
2
3import torch
4from torch.fx import symbolic_trace
5from torch.package import PackageExporter
6from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE
7
8
# Directory next to this script into which the generated .pt packages are written.
packaging_directory = "{}/package_bc".format(Path(__file__).parent)
# Clear torch.package's private TorchScript-serialization gate; presumably this is
# required so save_pickle() accepts the scripted module -- confirm against
# torch/package/package_exporter.py.
setattr(torch.package.package_exporter, "_gate_torchscript_serialization", False)
11
12
def generate_bc_packages():
    """Create the packages used for testing backwards compatibility.

    Builds three variants of ``package_a.test_nn_module.TestNnModule`` and saves
    each one into its own package under ``packaging_directory``, with all
    modules interned:

    * ``test_nn_module.pt``: the eager module (package ``"nn_module"``,
      resource ``"nn_module.pkl"``).
    * ``test_torchscript_module.pt``: the ``torch.jit.script``-ed module
      (package ``"torchscript_module"``, resource ``"torchscript_module.pkl"``).
    * ``test_fx_module.pt``: the ``torch.fx`` symbolically traced
      ``GraphModule`` (package ``"fx_module"``, resource ``"fx_module.pkl"``).

    Does nothing when ``IS_FBCODE`` is true and ``IS_SANDCASTLE`` is false.
    """
    # Only generate outside fbcode, or inside fbcode when on Sandcastle.
    if IS_FBCODE and not IS_SANDCASTLE:
        return

    # Imported here rather than at module top, presumably so that importing
    # this module does not require `package_a` on sys.path -- confirm.
    from package_a.test_nn_module import TestNnModule

    test_nn_module = TestNnModule()
    test_torchscript_module = torch.jit.script(TestNnModule())
    test_fx_module: torch.fx.GraphModule = symbolic_trace(TestNnModule())

    # (output file name, package name, resource name, object to pickle),
    # exported in this order.
    exports = (
        ("test_nn_module.pt", "nn_module", "nn_module.pkl", test_nn_module),
        (
            "test_torchscript_module.pt",
            "torchscript_module",
            "torchscript_module.pkl",
            test_torchscript_module,
        ),
        ("test_fx_module.pt", "fx_module", "fx_module.pkl", test_fx_module),
    )
    output_dir = Path(packaging_directory)
    for file_name, package, resource, obj in exports:
        with PackageExporter(output_dir / file_name) as exporter:
            # Intern every module ("**" matches all module names) so the
            # package bundles its own copy of the code it depends on.
            exporter.intern("**")
            exporter.save_pickle(package, resource, obj)
34
35
if __name__ == "__main__":
    # Script entry point: (re)generate the backwards-compatibility packages.
    generate_bc_packages()
38