xref: /aosp_15_r20/external/pytorch/test/package/test_load_bc_packages.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: package/deploy"]
2
3from pathlib import Path
4from unittest import skipIf
5
6from torch.package import PackageImporter
7from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE, run_tests
8
9
10try:
11    from .common import PackageTestCase
12except ImportError:
13    # Support the case where we run this file directly.
14    from common import PackageTestCase
15
# Directory holding the checked-in packages these tests load; it sits next to
# this file. Kept as a ``str`` because it is interpolated into path strings.
packaging_directory = str(Path(__file__).parent / "package_bc")
17
18
class TestLoadBCPackages(PackageTestCase):
    """Tests for checking that package loading is backwards compatible.

    Each test opens a package file stored under ``packaging_directory`` and
    unpickles one resource from it. There are no explicit assertions: a test
    passes when the load completes without raising.

    NOTE(review): the skip reason below mentions temporary files, but these
    tests only read files from ``packaging_directory`` -- confirm the reason
    is still accurate.
    """

    def _load_bc_pickle(self, name):
        """Unpickle resource ``<name>.pkl`` of package ``<name>`` from ``test_<name>.pt``.

        Args:
            name: Stem shared by the package file, the package name and the
                pickle resource (e.g. ``"nn_module"``).

        Returns:
            Whatever ``PackageImporter.load_pickle`` returns for the resource.
        """
        importer = PackageImporter(f"{packaging_directory}/test_{name}.pt")
        return importer.load_pickle(name, f"{name}.pkl")

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_load_bc_packages_nn_module(self):
        """Tests for backwards compatible nn module"""
        self._load_bc_pickle("nn_module")

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_load_bc_packages_torchscript_module(self):
        """Tests for backwards compatible torchscript module"""
        self._load_bc_pickle("torchscript_module")

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_load_bc_packages_fx_module(self):
        """Tests for backwards compatible fx module"""
        self._load_bc_pickle("fx_module")
48
49
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
52