# Owner(s): ["oncall: package/deploy"]

"""Core round-trip tests for the ``torch.package`` save_* and loading APIs."""

import pickle
from io import BytesIO
from pathlib import Path
from textwrap import dedent

from torch.package import PackageExporter, PackageImporter, sys_importer
from torch.testing._internal.common_utils import run_tests


try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase


# Directory containing this test file; the fixture sources packaged below
# (e.g. ``module_a.py`` and ``package_a/``) are resolved relative to it.
packaging_directory = Path(__file__).parent


class TestSaveLoad(PackageTestCase):
    """Core save_* and loading API tests."""

    def _check_save_module_roundtrip(self, buffer):
        """Save ``module_a`` and ``package_a`` into *buffer* and re-import them.

        Asserts that the copies imported through a ``PackageImporter`` expose the
        expected ``result`` attribute and are distinct objects from the modules
        imported normally in this process.
        """
        with PackageExporter(buffer) as he:
            import module_a
            import package_a

            he.save_module(module_a.__name__)
            he.save_module(package_a.__name__)
        buffer.seek(0)
        hi = PackageImporter(buffer)
        module_a_i = hi.import_module("module_a")
        self.assertEqual(module_a_i.result, "module_a")
        self.assertIsNot(module_a, module_a_i)
        package_a_i = hi.import_module("package_a")
        self.assertEqual(package_a_i.result, "package_a")
        self.assertIsNot(package_a_i, package_a)

    def test_saving_source(self):
        """Files and whole directories can be saved under new module names."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.save_source_file("foo", str(packaging_directory / "module_a.py"))
            he.save_source_file("foodir", str(packaging_directory / "package_a"))
        buffer.seek(0)
        hi = PackageImporter(buffer)
        foo = hi.import_module("foo")
        s = hi.import_module("foodir.subpackage")
        self.assertEqual(foo.result, "module_a")
        self.assertEqual(s.result, "package_a.subpackage")

    def test_saving_string(self):
        """A module saved from a source string can import ``math``.

        Both a direct import of ``math`` through the importer and the ``math``
        name inside the packaged module must be the ``math`` module of this
        process.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            src = dedent(
                """\
                import math
                the_math = math
                """
            )
            he.save_source_string("my_mod", src)
        buffer.seek(0)
        hi = PackageImporter(buffer)
        m = hi.import_module("math")
        import math

        self.assertIs(m, math)
        my_mod = hi.import_module("my_mod")
        self.assertIs(my_mod.math, math)

    def test_save_module(self):
        """``save_module`` round-trips a plain module and a package."""
        self._check_save_module_roundtrip(BytesIO())

    def test_dunder_imports(self):
        """Modules reached via ``__import__`` calls are resolvable after packaging."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            import package_b

            obj = package_b.PackageBObject
            he.intern("**")
            he.save_pickle("res", "obj.pkl", obj)

        buffer.seek(0)
        hi = PackageImporter(buffer)
        loaded_obj = hi.load_pickle("res", "obj.pkl")
        # The unpickled object comes from the package, not this process.
        self.assertIsNot(loaded_obj, obj)

        package_b = hi.import_module("package_b")
        self.assertEqual(package_b.result, "package_b")

        math = hi.import_module("math")
        self.assertEqual(math.__name__, "math")

        xml_sub_sub_package = hi.import_module("xml.sax.xmlreader")
        self.assertEqual(xml_sub_sub_package.__name__, "xml.sax.xmlreader")

        subpackage_1 = hi.import_module("package_b.subpackage_1")
        self.assertEqual(subpackage_1.result, "subpackage_1")

        subpackage_2 = hi.import_module("package_b.subpackage_2")
        self.assertEqual(subpackage_2.result, "subpackage_2")

        subsubpackage_0 = hi.import_module("package_b.subpackage_0.subsubpackage_0")
        self.assertEqual(subsubpackage_0.result, "subsubpackage_0")

    def test_bad_dunder_imports(self):
        """Test to ensure bad __imports__ don't cause PackageExporter to fail."""
        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_source_string(
                "m", '__import__(these, unresolvable, "things", wont, crash, me)'
            )

    def test_save_module_binary(self):
        """``save_module`` round-trip through a ``BytesIO`` buffer.

        Runs the same checks as ``test_save_module``; it is kept as its own test
        name so existing test selections keep working.
        """
        self._check_save_module_roundtrip(BytesIO())

    def test_pickle(self):
        """Pickling an object packages its dependencies and nothing else."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.intern("**")
            he.save_pickle("obj", "obj.pkl", obj2)
        buffer.seek(0)
        hi = PackageImporter(buffer)

        # check we got dependencies
        sp = hi.import_module("package_a.subpackage")
        # check we didn't get other stuff
        with self.assertRaises(ImportError):
            hi.import_module("module_a")

        obj_loaded = hi.load_pickle("obj", "obj.pkl")
        self.assertIsNot(obj2, obj_loaded)
        self.assertIsInstance(obj_loaded.obj, sp.PackageASubpackageObject)
        self.assertIsNot(
            package_a.subpackage.PackageASubpackageObject, sp.PackageASubpackageObject
        )

    def test_pickle_long_name_with_protocol_4(self):
        """A callable with a very long name round-trips with ``pickle_protocol=4``."""
        import package_a.long_name

        container = []

        # Indirectly grab the function to avoid pasting a 256 character
        # function into the test
        package_a.long_name.add_function(container)

        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.intern("**")
            exporter.save_pickle(
                "container", "container.pkl", container, pickle_protocol=4
            )

        buffer.seek(0)
        importer = PackageImporter(buffer)
        unpickled_container = importer.load_pickle("container", "container.pkl")
        self.assertIsNot(container, unpickled_container)
        self.assertEqual(len(unpickled_container), 1)
        self.assertEqual(container[0](), unpickled_container[0]())

    def test_exporting_mismatched_code(self):
        """
        If an object with the same qualified name is loaded from different
        packages, the user should get an error if they try to re-save the
        object with the wrong package's source code.
        """
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        b1 = BytesIO()
        with PackageExporter(b1) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj2)

        b1.seek(0)
        importer1 = PackageImporter(b1)
        loaded1 = importer1.load_pickle("obj", "obj.pkl")

        b1.seek(0)
        importer2 = PackageImporter(b1)
        loaded2 = importer2.load_pickle("obj", "obj.pkl")

        def make_exporter():
            # importer1 is listed before sys_importer so that the
            # 'PackageAObject' defined in 'importer1' is found first.
            return PackageExporter(BytesIO(), importer=[importer1, sys_importer])

        # This should fail. The 'PackageAObject' type defined from 'importer1'
        # is not necessarily the same as 'obj2's version of 'PackageAObject'.
        pe = make_exporter()
        with self.assertRaises(pickle.PicklingError):
            pe.save_pickle("obj", "obj.pkl", obj2)

        # This should also fail. The 'PackageAObject' type defined from 'importer1'
        # is not necessarily the same as the one defined from 'importer2'.
        pe = make_exporter()
        with self.assertRaises(pickle.PicklingError):
            pe.save_pickle("obj", "obj.pkl", loaded2)

        # This should succeed. The 'PackageAObject' type defined from
        # 'importer1' is a match for the one used by loaded1.
        pe = make_exporter()
        pe.save_pickle("obj", "obj.pkl", loaded1)

    def test_save_imported_module(self):
        """Saving a module that came from another PackageImporter should work."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.intern("**")
            exporter.save_pickle("model", "model.pkl", obj2)

        buffer.seek(0)

        importer = PackageImporter(buffer)
        imported_obj2 = importer.load_pickle("model", "model.pkl")
        imported_obj2_module = imported_obj2.__class__.__module__

        # Should export without error.
        buffer2 = BytesIO()
        with PackageExporter(buffer2, importer=(importer, sys_importer)) as exporter:
            exporter.intern("**")
            exporter.save_module(imported_obj2_module)

    def test_save_imported_module_using_package_importer(self):
        """Exercise a corner case: re-packaging a module that uses `torch_package_importer`"""
        import package_a.use_torch_package_importer  # noqa: F401

        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.intern("**")
            exporter.save_module("package_a.use_torch_package_importer")

        buffer.seek(0)

        importer = PackageImporter(buffer)

        # Should export without error.
        buffer2 = BytesIO()
        with PackageExporter(buffer2, importer=(importer, sys_importer)) as exporter:
            exporter.intern("**")
            exporter.save_module("package_a.use_torch_package_importer")


if __name__ == "__main__":
    run_tests()