# Owner(s): ["oncall: package/deploy"]

from pathlib import Path
from unittest import skipIf

from torch.package import PackageImporter
from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE, run_tests


try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase

# Directory holding the checked-in packages produced by older versions.
packaging_directory = f"{Path(__file__).parent}/package_bc"

# Shared skip decorator; unittest.skipIf returns a decorator that can be
# applied to any number of test methods.
# NOTE(review): these tests read checked-in fixtures rather than temporary
# files, so the reason text may be stale — confirm before rewording it.
_skip_in_fbcode = skipIf(
    IS_FBCODE or IS_SANDCASTLE,
    "Tests that use temporary files are disabled in fbcode",
)


class TestLoadBCPackages(PackageTestCase):
    """Tests for checking that loading has backwards compatibility."""

    def _load_bc_package(self, filename, package, resource):
        """Load ``resource`` from ``package`` inside a checked-in BC package.

        Args:
            filename: Name of the ``.pt`` file under ``packaging_directory``.
            package: Package name passed to ``PackageImporter.load_pickle``.
            resource: Pickle resource name inside ``package``.

        Returns:
            The unpickled object. The helper also asserts that it is not ``None``.
        """
        importer = PackageImporter(f"{packaging_directory}/{filename}")
        loaded = importer.load_pickle(package, resource)
        # Check the result itself, not only that loading did not raise.
        self.assertIsNotNone(loaded)
        return loaded

    @_skip_in_fbcode
    def test_load_bc_packages_nn_module(self):
        """Tests for backwards compatible nn module"""
        self._load_bc_package("test_nn_module.pt", "nn_module", "nn_module.pkl")

    @_skip_in_fbcode
    def test_load_bc_packages_torchscript_module(self):
        """Tests for backwards compatible torchscript module"""
        self._load_bc_package(
            "test_torchscript_module.pt",
            "torchscript_module",
            "torchscript_module.pkl",
        )

    @_skip_in_fbcode
    def test_load_bc_packages_fx_module(self):
        """Tests for backwards compatible fx module"""
        self._load_bc_package("test_fx_module.pt", "fx_module", "fx_module.pkl")


if __name__ == "__main__":
    run_tests()