# Owner(s): ["oncall: jit"]

import io
import os
import sys
import zipfile
from typing import Union

import torch
from torch.testing import FileCheck


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestUpgraders(JitTestCase):
    """Tests for TorchScript operator upgraders and the operator version map.

    Several tests mutate process-global registries (the upgraders map, the
    operator version map, and the package-version calculator flag). Every
    such mutation is undone in a ``finally`` clause so that a failing
    assertion cannot leak state into tests that run afterwards.
    """

    def _load_model_version(self, loaded_model):
        """Save ``loaded_model`` to an in-memory archive and return its version.

        Args:
            loaded_model: a scripted function or module accepted by
                ``torch.jit.save``.

        Returns:
            The integer version number recorded inside the saved zip archive.

        Raises:
            KeyError: if neither known version-record path is in the archive.
        """
        buffer = io.BytesIO()
        torch.jit.save(loaded_model, buffer)
        buffer.seek(0)
        with zipfile.ZipFile(buffer) as zipped_model:
            # There was a change in how we store the version number in a
            # package between version 3 and 7, so we have to check for both.
            try:
                raw_version = zipped_model.read("archive/version")
            except KeyError:
                raw_version = zipped_model.read("archive/.data/version")
        return int(raw_version.decode("utf-8"))

    # TODO (tugsuu) We should ideally be generating these test cases.
    def test_populated_upgrader_graph(self):
        """Loading a model a second time must leave the upgraders map unchanged."""

        @torch.jit.script
        def f():
            return 0

        buffer = io.BytesIO()
        torch.jit.save(f, buffer)
        buffer.seek(0)
        torch.jit.load(buffer)
        upgraders_size = torch._C._get_upgraders_map_size()
        upgraders_dump = torch._C._dump_upgraders_map()
        # Make sure we populate the upgrader map only once: load the model
        # again and check that the upgrader map has the same content.
        buffer.seek(0)
        torch.jit.load(buffer)
        upgraders_size_second_time = torch._C._get_upgraders_map_size()
        upgraders_dump_second_time = torch._C._dump_upgraders_map()
        self.assertEqual(upgraders_size, upgraders_size_second_time)
        self.assertEqual(upgraders_dump, upgraders_dump_second_time)

    def test_add_value_to_version_map(self):
        """Adding then removing an op entry changes the version map by one key."""
        op_name = "aten::_test_serialization_subcmul"
        map_before_test = torch._C._get_operator_version_map()

        upgrader_bumped_version = 3
        upgrader_name = "_test_serialization_subcmul_0_2"
        upgrader_schema = "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor"
        dummy_entry = torch._C._UpgraderEntry(
            upgrader_bumped_version, upgrader_name, upgrader_schema
        )

        torch._C._test_only_add_entry_to_op_version_map(op_name, dummy_entry)
        try:
            map_after_test = torch._C._get_operator_version_map()
            self.assertIn(op_name, map_after_test)
            self.assertEqual(len(map_after_test) - len(map_before_test), 1)
        finally:
            # Always undo the global mutation, even if an assertion failed.
            torch._C._test_only_remove_entry_to_op_version_map(op_name)

        map_after_remove_test = torch._C._get_operator_version_map()
        self.assertNotIn(op_name, map_after_remove_test)
        self.assertEqual(len(map_after_remove_test), len(map_before_test))

    def test_populated_test_upgrader_graph(self):
        """Test-only upgraders can be added to and removed from the upgraders map."""

        @torch.jit.script
        def f():
            return 0

        buffer = io.BytesIO()
        torch.jit.save(f, buffer)
        buffer.seek(0)
        torch.jit.load(buffer)

        # upgrader map should have populated now
        upgraders_size = torch._C._get_upgraders_map_size()

        test_map = {"a": str(torch._C.Graph()), "c": str(torch._C.Graph())}
        torch._C._test_only_populate_upgraders(test_map)
        try:
            upgraders_size_after_test = torch._C._get_upgraders_map_size()
            self.assertEqual(upgraders_size_after_test - upgraders_size, 2)
            upgraders_dump = torch._C._dump_upgraders_map()
            self.assertIn("a", upgraders_dump)
            self.assertIn("c", upgraders_dump)
        finally:
            # Always undo the global mutation, even if an assertion failed.
            torch._C._test_only_remove_upgraders(test_map)

        upgraders_size_after_remove_test = torch._C._get_upgraders_map_size()
        self.assertEqual(upgraders_size_after_remove_test, upgraders_size)
        upgraders_dump_after_remove_test = torch._C._dump_upgraders_map()
        self.assertNotIn("a", upgraders_dump_after_remove_test)
        self.assertNotIn("c", upgraders_dump_after_remove_test)

    def test_aten_div_tensor_at_3(self):
        """A v3 aten::div model is upgraded on load and re-saves as version 4."""
        model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_v3.pt"
        loaded_model = torch.jit.load(model_path)
        # there are 3 aten::div in this model
        # And the upgrader for aten::div uses two
        # div's because of if/else branch
        FileCheck().check("prim::If").run(loaded_model.graph)
        FileCheck().check_count("aten::div", 6).run(loaded_model.graph)

        buffer = io.BytesIO()
        torch.jit.save(loaded_model, buffer)
        buffer.seek(0)
        version = self._load_model_version(loaded_model)
        self.assertEqual(version, 4)
        loaded_model_twice = torch.jit.load(buffer)
        # we compare by code because graph variable names
        # can be different every time
        self.assertEqual(loaded_model.code, loaded_model_twice.code)

    def test_aten_full_other_variants(self):
        """A scripted ``torch.full(..., names=...)`` call saves as version 5.

        The version is checked with upgrader-based package-version calculation
        both disabled and enabled; the flag's original value is restored
        afterwards.
        """

        def test_func():
            a = torch.full([4, 5, 6], 4, names=["a", "b", "c"], dtype=torch.int64)
            return a

        scripted_func = torch.jit.script(test_func)
        buffer = io.BytesIO()
        torch.jit.save(scripted_func, buffer)

        current_flag_value = torch._C._get_version_calculator_flag()
        try:
            # calculate based on old version
            torch._C._calculate_package_version_based_on_upgraders(False)
            buffer.seek(0)
            loaded_func = torch.jit.load(buffer)
            version = self._load_model_version(loaded_func)
            self.assertEqual(version, 5)

            # calculate based on new version
            torch._C._calculate_package_version_based_on_upgraders(True)
            buffer.seek(0)
            loaded_func = torch.jit.load(buffer)
            version = self._load_model_version(loaded_func)
            self.assertEqual(version, 5)
        finally:
            # make sure we preserve the old behaviour, even if an assertion failed
            torch._C._calculate_package_version_based_on_upgraders(current_flag_value)

    def test_aten_linspace(self):
        """A v7 linspace model runs on mixed inputs and re-saves as version 8."""
        model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_v7.ptl"
        loaded_model = torch.jit.load(model_path)
        sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
        for a, b in sample_inputs:
            output_with_step, output_without_step = loaded_model(a, b)
            # when no step is given, should have used 100
            self.assertEqual(output_without_step.size(dim=0), 100)
            self.assertEqual(output_with_step.size(dim=0), 5)

        version = self._load_model_version(loaded_model)
        self.assertEqual(version, 8)

    def test_aten_linspace_out(self):
        """A v7 linspace.out model runs on mixed inputs and re-saves as version 8."""
        model_path = (
            pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_out_v7.ptl"
        )
        loaded_model = torch.jit.load(model_path)
        sample_inputs = (
            (3, 10, torch.empty((100,), dtype=torch.int64)),
            (-10, 10, torch.empty((100,), dtype=torch.int64)),
            (4.0, 6.0, torch.empty((100,), dtype=torch.float64)),
            (3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64)),
        )
        for a, b, c in sample_inputs:
            output = loaded_model(a, b, c)
            # when no step is given, should have used 100
            self.assertEqual(output.size(dim=0), 100)

        version = self._load_model_version(loaded_model)
        self.assertEqual(version, 8)

    def test_aten_logspace(self):
        """A v8 logspace model runs on mixed inputs and re-saves as version 9."""
        model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_v8.ptl"
        loaded_model = torch.jit.load(model_path)
        sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
        for a, b in sample_inputs:
            output_with_step, output_without_step = loaded_model(a, b)
            # when no step is given, should have used 100
            self.assertEqual(output_without_step.size(dim=0), 100)
            self.assertEqual(output_with_step.size(dim=0), 5)

        version = self._load_model_version(loaded_model)
        self.assertEqual(version, 9)

    def test_aten_logspace_out(self):
        """A v8 logspace.out model runs on mixed inputs and re-saves as version 9."""
        model_path = (
            pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_out_v8.ptl"
        )
        loaded_model = torch.jit.load(model_path)
        sample_inputs = (
            (3, 10, torch.empty((100,), dtype=torch.int64)),
            (-10, 10, torch.empty((100,), dtype=torch.int64)),
            (4.0, 6.0, torch.empty((100,), dtype=torch.float64)),
            (3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64)),
        )
        for a, b, c in sample_inputs:
            output = loaded_model(a, b, c)
            # when no step is given, should have used 100
            self.assertEqual(output.size(dim=0), 100)

        version = self._load_model_version(loaded_model)
        self.assertEqual(version, 9)

    def test_aten_test_serialization(self):
        """A registered test-only upgrader is applied when loading a v2 model.

        Registers a dummy version-map entry and upgrader for
        ``aten::_test_serialization_subcmul``, loads the fixture, checks that
        the upgrader's ``aten::mul``/``aten::sub`` ops appear in the graph and
        that the model re-saves as version 3. Both registrations are removed
        afterwards, even if an assertion fails.
        """
        model_path = (
            pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt"
        )
        op_name = "aten::_test_serialization_subcmul"

        # add test version entry to the version map
        upgrader_bumped_version = 3
        upgrader_name = "_test_serialization_subcmul_0_2"
        upgrader_schema = "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor"
        dummy_entry = torch._C._UpgraderEntry(
            upgrader_bumped_version, upgrader_name, upgrader_schema
        )

        torch._C._test_only_add_entry_to_op_version_map(op_name, dummy_entry)
        try:
            # add test upgrader in the upgraders map
            @torch.jit.script
            def _test_serialization_subcmul_0_2(
                self: torch.Tensor, other: torch.Tensor, alpha: Union[int, float] = 2
            ) -> torch.Tensor:
                return other - (self * alpha)

            test_upgraders = {
                "_test_serialization_subcmul_0_2": str(
                    _test_serialization_subcmul_0_2.graph
                )
            }
            torch._C._test_only_populate_upgraders(test_upgraders)
            try:
                # test if the server is able to find the test upgraders and apply to IR
                loaded_model = torch.jit.load(model_path)
                FileCheck().check_count("aten::mul", 2).run(loaded_model.graph)
                FileCheck().check_count("aten::sub", 2).run(loaded_model.graph)

                buffer = io.BytesIO()
                torch.jit.save(loaded_model, buffer)
                buffer.seek(0)
                version = self._load_model_version(loaded_model)
                self.assertEqual(version, 3)
                loaded_model_twice = torch.jit.load(buffer)
                # we compare by code because graph variable names
                # can be different every time
                self.assertEqual(loaded_model.code, loaded_model_twice.code)
            finally:
                torch._C._test_only_remove_upgraders(test_upgraders)
        finally:
            # Always undo the global mutation, even if an assertion failed.
            torch._C._test_only_remove_entry_to_op_version_map(op_name)

    def test_aten_div_scalar_at_3(self):
        """A v3 div-by-scalar model is upgraded and round-trips to equal outputs."""
        model_path = (
            pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt"
        )
        loaded_model = torch.jit.load(model_path)
        FileCheck().check("prim::If").run(loaded_model.graph)
        FileCheck().check_count("aten::div", 2).run(loaded_model.graph)

        buffer = io.BytesIO()
        torch.jit.save(loaded_model, buffer)
        buffer.seek(0)
        version = self._load_model_version(loaded_model)
        self.assertEqual(version, 4)
        loaded_model_twice = torch.jit.load(buffer)

        self.assertEqual(
            loaded_model(torch.Tensor([5.0, 3.0]), 2.0),
            loaded_model_twice(torch.Tensor([5.0, 3.0]), 2.0),
        )

    def test_aten_div_tensor_out_at_3(self):
        """A v3 div.out model is upgraded on load and re-saves as version 4."""
        model_path = (
            pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt"
        )
        loaded_model = torch.jit.load(model_path)
        FileCheck().check("prim::If").run(loaded_model.graph)
        FileCheck().check_count("aten::div", 2).run(loaded_model.graph)

        buffer = io.BytesIO()
        torch.jit.save(loaded_model, buffer)
        buffer.seek(0)
        version = self._load_model_version(loaded_model)
        self.assertEqual(version, 4)
        loaded_model_twice = torch.jit.load(buffer)
        # we compare by code because graph variable names
        # can be different every time
        self.assertEqual(loaded_model.code, loaded_model_twice.code)

    def test_aten_full_at_4(self):
        """A v4 integer-valued full model gains one aten::Float and saves as v5."""
        model_path = (
            pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt"
        )
        loaded_model = torch.jit.load(model_path)
        FileCheck().check_count("aten::Float", 1).run(loaded_model.graph)
        FileCheck().check_count("aten::full", 2).run(loaded_model.graph)

        buffer = io.BytesIO()
        torch.jit.save(loaded_model, buffer)
        buffer.seek(0)
        version = self._load_model_version(loaded_model)
        self.assertEqual(version, 5)
        loaded_model_twice = torch.jit.load(buffer)
        # we compare by code because graph variable names
        # can be different every time
        self.assertEqual(loaded_model.code, loaded_model_twice.code)

    def test_aten_full_out_at_4(self):
        """A v4 full model keeps five aten::full nodes and re-saves as version 5."""
        model_path = (
            pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt"
        )
        loaded_model = torch.jit.load(model_path)
        FileCheck().check_count("aten::full", 5).run(loaded_model.graph)
        version = self._load_model_version(loaded_model)
        self.assertEqual(version, 5)