# Owner(s): ["oncall: package/deploy"]

import os
import zipfile
from sys import version_info
from tempfile import TemporaryDirectory
from textwrap import dedent
from unittest import skipIf

import torch
from torch.package import PackageExporter, PackageImporter
from torch.testing._internal.common_utils import (
    IS_FBCODE,
    IS_SANDCASTLE,
    IS_WINDOWS,
    run_tests,
)


try:
    from torchvision.models import resnet18

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = skipIf(not HAS_TORCHVISION, "no torchvision")


try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase

from pathlib import Path


packaging_directory = Path(__file__).parent


@skipIf(
    IS_FBCODE or IS_SANDCASTLE or IS_WINDOWS,
    "Tests that use temporary files are disabled in fbcode",
)
class DirectoryReaderTest(PackageTestCase):
    """Tests use of DirectoryReader as accessor for opened packages.

    Each test exports a package archive, extracts it into a temporary
    directory, and then opens the extracted directory (rather than the
    archive file) with ``PackageImporter``.
    """

    def _extract_package(self, filename, temp_dir):
        """Extract the package archive ``filename`` into ``temp_dir``.

        Returns the path of the extracted package directory. The archive's
        records live under a top-level directory named after the archive file,
        so that directory is ``temp_dir / basename(filename)``.
        """
        # Use a context manager so the archive handle is closed as soon as the
        # extraction is done instead of leaking until garbage collection.
        with zipfile.ZipFile(filename, "r") as zip_file:
            zip_file.extractall(path=temp_dir)
        return Path(temp_dir) / Path(filename).name

    @skipIfNoTorchVision
    @skipIf(
        True,
        "Does not work with latest TorchVision, see https://github.com/pytorch/pytorch/issues/81115",
    )
    def test_loading_pickle(self):
        """
        Test basic saving and loading of modules and pickles from a DirectoryReader.
        """
        resnet = resnet18()

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.intern("**")
            e.save_pickle("model", "model.pkl", resnet)

        with TemporaryDirectory() as temp_dir:
            importer = PackageImporter(self._extract_package(filename, temp_dir))
            dir_mod = importer.load_pickle("model", "model.pkl")
            sample = torch.rand(1, 3, 224, 224)
            self.assertEqual(dir_mod(sample), resnet(sample))

    def test_loading_module(self):
        """
        Test basic saving and loading of a package from a DirectoryReader.
        """
        import package_a

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_module("package_a")

        with TemporaryDirectory() as temp_dir:
            dir_importer = PackageImporter(self._extract_package(filename, temp_dir))
            dir_mod = dir_importer.import_module("package_a")
            self.assertEqual(dir_mod.result, package_a.result)

    def test_loading_has_record(self):
        """
        Test DirectoryReader's has_record().
        """
        import package_a  # noqa: F401

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_module("package_a")

        with TemporaryDirectory() as temp_dir:
            dir_importer = PackageImporter(self._extract_package(filename, temp_dir))
            self.assertTrue(dir_importer.zip_reader.has_record("package_a/__init__.py"))
            # A directory entry is not itself a record.
            self.assertFalse(dir_importer.zip_reader.has_record("package_a"))

    @skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
    def test_resource_reader(self):
        """Tests DirectoryReader as the base for get_resource_reader."""
        filename = self.temp()
        with PackageExporter(filename) as pe:
            # Layout looks like:
            #    package
            #    |-- one/
            #    |   |-- a.txt
            #    |   |-- b.txt
            #    |   |-- c.txt
            #    |   +-- three/
            #    |       |-- d.txt
            #    |       +-- e.txt
            #    +-- two/
            #       |-- f.txt
            #       +-- g.txt
            pe.save_text("one", "a.txt", "hello, a!")
            pe.save_text("one", "b.txt", "hello, b!")
            pe.save_text("one", "c.txt", "hello, c!")

            pe.save_text("one.three", "d.txt", "hello, d!")
            pe.save_text("one.three", "e.txt", "hello, e!")

            pe.save_text("two", "f.txt", "hello, f!")
            pe.save_text("two", "g.txt", "hello, g!")

        with TemporaryDirectory() as temp_dir:
            package_dir = self._extract_package(filename, temp_dir)
            importer = PackageImporter(package_dir)
            reader_one = importer.get_resource_reader("one")

            # Different behavior from still zipped archives: resources have a
            # real filesystem path inside the extracted directory.
            resource_path = os.path.join(package_dir, "one", "a.txt")
            self.assertEqual(reader_one.resource_path("a.txt"), resource_path)

            self.assertTrue(reader_one.is_resource("a.txt"))
            self.assertEqual(
                reader_one.open_resource("a.txt").getbuffer(), b"hello, a!"
            )
            # Sub-packages are listed in contents() but are not resources.
            self.assertFalse(reader_one.is_resource("three"))
            self.assertSequenceEqual(
                sorted(reader_one.contents()), ["a.txt", "b.txt", "c.txt", "three"]
            )

            reader_two = importer.get_resource_reader("two")
            self.assertTrue(reader_two.is_resource("f.txt"))
            self.assertEqual(
                reader_two.open_resource("f.txt").getbuffer(), b"hello, f!"
            )
            self.assertSequenceEqual(sorted(reader_two.contents()), ["f.txt", "g.txt"])

            reader_one_three = importer.get_resource_reader("one.three")
            self.assertTrue(reader_one_three.is_resource("d.txt"))
            self.assertEqual(
                reader_one_three.open_resource("d.txt").getbuffer(), b"hello, d!"
            )
            self.assertSequenceEqual(
                sorted(reader_one_three.contents()), ["d.txt", "e.txt"]
            )

            self.assertIsNone(importer.get_resource_reader("nonexistent_package"))

    @skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
    def test_package_resource_access(self):
        """Packaged modules should be able to use the importlib.resources API to access
        resources saved in the package.
        """
        mod_src = dedent(
            """\
            import importlib.resources
            import my_cool_resources

            def secret_message():
                return importlib.resources.read_text(my_cool_resources, 'sekrit.txt')
            """
        )
        filename = self.temp()
        with PackageExporter(filename) as pe:
            pe.save_source_string("foo.bar", mod_src)
            pe.save_text("my_cool_resources", "sekrit.txt", "my sekrit plays")

        with TemporaryDirectory() as temp_dir:
            dir_importer = PackageImporter(self._extract_package(filename, temp_dir))
            self.assertEqual(
                dir_importer.import_module("foo.bar").secret_message(),
                "my sekrit plays",
            )

    @skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
    def test_importer_access(self):
        """Packaged code can read text/binary resources through the
        ``torch_package_importer`` module when loaded from a directory.
        """
        filename = self.temp()
        with PackageExporter(filename) as he:
            he.save_text("main", "main", "my string")
            he.save_binary("main", "main_binary", b"my string")
            src = dedent(
                """\
                import importlib
                import torch_package_importer as resources

                t = resources.load_text('main', 'main')
                b = resources.load_binary('main', 'main_binary')
                """
            )
            he.save_source_string("main", src, is_package=True)

        with TemporaryDirectory() as temp_dir:
            dir_importer = PackageImporter(self._extract_package(filename, temp_dir))
            m = dir_importer.import_module("main")
            self.assertEqual(m.t, "my string")
            self.assertEqual(m.b, b"my string")

    @skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
    def test_resource_access_by_path(self):
        """
        Tests that packaged code can use importlib.resources.path.
        """
        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_binary("string_module", "my_string", b"my string")
            src = dedent(
                """\
                import importlib.resources
                import string_module

                with importlib.resources.path(string_module, 'my_string') as path:
                    with open(path, mode='r', encoding='utf-8') as f:
                        s = f.read()
                """
            )
            e.save_source_string("main", src, is_package=True)

        with TemporaryDirectory() as temp_dir:
            dir_importer = PackageImporter(self._extract_package(filename, temp_dir))
            m = dir_importer.import_module("main")
            self.assertEqual(m.s, "my string")

    def test_scriptobject_failure_message(self):
        """
        Test that loading a ScriptModule from a package opened as a directory
        fails with a clear error message; this is currently not supported.
        """
        from package_a.test_module import ModWithTensor

        scripted_mod = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_pickle("res", "mod.pkl", scripted_mod)

        with TemporaryDirectory() as temp_dir:
            dir_importer = PackageImporter(self._extract_package(filename, temp_dir))
            # Only the unpickling step is expected to fail; keep the assertion
            # scoped to it so setup errors are not mistaken for the failure.
            with self.assertRaisesRegex(
                RuntimeError,
                "Loading ScriptObjects from a PackageImporter created from a "
                "directory is not supported. Use a package archive file instead.",
            ):
                dir_importer.load_pickle("res", "mod.pkl")


if __name__ == "__main__":
    run_tests()