1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport io 4*da0073e9SAndroid Build Coastguard Workerimport os 5*da0073e9SAndroid Build Coastguard Workerimport sys 6*da0073e9SAndroid Build Coastguard Workerimport zipfile 7*da0073e9SAndroid Build Coastguard Workerfrom typing import Union 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerimport torch 10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Worker# Make the helper files in test/ importable 14*da0073e9SAndroid Build Coastguard Workerpytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 15*da0073e9SAndroid Build Coastguard Workersys.path.append(pytorch_test_dir) 16*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 20*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 21*da0073e9SAndroid Build Coastguard Worker "This test file is not meant to be run directly, use:\n\n" 22*da0073e9SAndroid Build Coastguard Worker "\tpython test/test_jit.py TESTNAME\n\n" 23*da0073e9SAndroid Build Coastguard Worker "instead." 24*da0073e9SAndroid Build Coastguard Worker ) 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Workerclass TestUpgraders(JitTestCase): 28*da0073e9SAndroid Build Coastguard Worker def _load_model_version(self, loaded_model): 29*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 30*da0073e9SAndroid Build Coastguard Worker torch.jit.save(loaded_model, buffer) 31*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 32*da0073e9SAndroid Build Coastguard Worker zipped_model = zipfile.ZipFile(buffer) 33*da0073e9SAndroid Build Coastguard Worker # there was a change in how we store version number 34*da0073e9SAndroid Build Coastguard Worker # in a package between version 3 and 7. 35*da0073e9SAndroid Build Coastguard Worker # So we have to check for both. 36*da0073e9SAndroid Build Coastguard Worker try: 37*da0073e9SAndroid Build Coastguard Worker version = int(zipped_model.read("archive/version").decode("utf-8")) 38*da0073e9SAndroid Build Coastguard Worker return version 39*da0073e9SAndroid Build Coastguard Worker except KeyError: 40*da0073e9SAndroid Build Coastguard Worker version = int(zipped_model.read("archive/.data/version").decode("utf-8")) 41*da0073e9SAndroid Build Coastguard Worker return version 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Worker # TODO (tugsuu) We should ideally be generating this test cases. 44*da0073e9SAndroid Build Coastguard Worker def test_populated_upgrader_graph(self): 45*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 46*da0073e9SAndroid Build Coastguard Worker def f(): 47*da0073e9SAndroid Build Coastguard Worker return 0 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 50*da0073e9SAndroid Build Coastguard Worker torch.jit.save(f, buffer) 51*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 52*da0073e9SAndroid Build Coastguard Worker torch.jit.load(buffer) 53*da0073e9SAndroid Build Coastguard Worker upgraders_size = torch._C._get_upgraders_map_size() 54*da0073e9SAndroid Build Coastguard Worker upgraders_dump = torch._C._dump_upgraders_map() 55*da0073e9SAndroid Build Coastguard Worker # make sure we only populate the upgrader map only once 56*da0073e9SAndroid Build Coastguard Worker # so we load it again and make sure the upgrader map has 57*da0073e9SAndroid Build Coastguard Worker # same content 58*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 59*da0073e9SAndroid Build Coastguard Worker torch.jit.load(buffer) 60*da0073e9SAndroid Build Coastguard Worker upgraders_size_second_time = torch._C._get_upgraders_map_size() 61*da0073e9SAndroid Build Coastguard Worker upgraders_dump_second_time = torch._C._dump_upgraders_map() 62*da0073e9SAndroid Build Coastguard Worker self.assertTrue(upgraders_size == upgraders_size_second_time) 63*da0073e9SAndroid Build Coastguard Worker self.assertTrue(upgraders_dump == upgraders_dump_second_time) 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker def test_add_value_to_version_map(self): 66*da0073e9SAndroid Build Coastguard Worker map_before_test = torch._C._get_operator_version_map() 67*da0073e9SAndroid Build Coastguard Worker 68*da0073e9SAndroid Build Coastguard Worker upgrader_bumped_version = 3 69*da0073e9SAndroid Build Coastguard Worker upgrader_name = "_test_serialization_subcmul_0_2" 70*da0073e9SAndroid Build Coastguard Worker upgrader_schema = "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor" 71*da0073e9SAndroid Build Coastguard Worker dummy_entry = torch._C._UpgraderEntry( 72*da0073e9SAndroid Build Coastguard Worker upgrader_bumped_version, upgrader_name, upgrader_schema 73*da0073e9SAndroid Build Coastguard Worker ) 74*da0073e9SAndroid Build Coastguard Worker 75*da0073e9SAndroid Build Coastguard Worker torch._C._test_only_add_entry_to_op_version_map( 76*da0073e9SAndroid Build Coastguard Worker "aten::_test_serialization_subcmul", dummy_entry 77*da0073e9SAndroid Build Coastguard Worker ) 78*da0073e9SAndroid Build Coastguard Worker map_after_test = torch._C._get_operator_version_map() 79*da0073e9SAndroid Build Coastguard Worker self.assertTrue("aten::_test_serialization_subcmul" in map_after_test) 80*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(map_after_test) - len(map_before_test) == 1) 81*da0073e9SAndroid Build Coastguard Worker torch._C._test_only_remove_entry_to_op_version_map( 82*da0073e9SAndroid Build Coastguard Worker "aten::_test_serialization_subcmul" 83*da0073e9SAndroid Build Coastguard Worker ) 84*da0073e9SAndroid Build Coastguard Worker map_after_remove_test = torch._C._get_operator_version_map() 85*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 86*da0073e9SAndroid Build Coastguard Worker "aten::_test_serialization_subcmul" not in map_after_remove_test 87*da0073e9SAndroid Build Coastguard Worker ) 88*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(map_after_remove_test), len(map_before_test)) 89*da0073e9SAndroid Build Coastguard Worker 90*da0073e9SAndroid Build Coastguard Worker def test_populated_test_upgrader_graph(self): 91*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 92*da0073e9SAndroid Build Coastguard Worker def f(): 93*da0073e9SAndroid Build Coastguard Worker return 0 94*da0073e9SAndroid Build Coastguard Worker 95*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 96*da0073e9SAndroid Build Coastguard Worker torch.jit.save(f, buffer) 97*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 98*da0073e9SAndroid Build Coastguard Worker torch.jit.load(buffer) 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker # upgrader map should have populated now 101*da0073e9SAndroid Build Coastguard Worker upgraders_size = torch._C._get_upgraders_map_size() 102*da0073e9SAndroid Build Coastguard Worker 103*da0073e9SAndroid Build Coastguard Worker test_map = {"a": str(torch._C.Graph()), "c": str(torch._C.Graph())} 104*da0073e9SAndroid Build Coastguard Worker torch._C._test_only_populate_upgraders(test_map) 105*da0073e9SAndroid Build Coastguard Worker upgraders_size_after_test = torch._C._get_upgraders_map_size() 106*da0073e9SAndroid Build Coastguard Worker self.assertEqual(upgraders_size_after_test - upgraders_size, 2) 107*da0073e9SAndroid Build Coastguard Worker upgraders_dump = torch._C._dump_upgraders_map() 108*da0073e9SAndroid Build Coastguard Worker self.assertTrue("a" in upgraders_dump) 109*da0073e9SAndroid Build Coastguard Worker self.assertTrue("c" in upgraders_dump) 110*da0073e9SAndroid Build Coastguard Worker 111*da0073e9SAndroid Build Coastguard Worker torch._C._test_only_remove_upgraders(test_map) 112*da0073e9SAndroid Build Coastguard Worker upgraders_size_after_remove_test = torch._C._get_upgraders_map_size() 113*da0073e9SAndroid Build Coastguard Worker self.assertTrue(upgraders_size_after_remove_test == upgraders_size) 114*da0073e9SAndroid Build Coastguard Worker upgraders_dump_after_remove_test = torch._C._dump_upgraders_map() 115*da0073e9SAndroid Build Coastguard Worker self.assertTrue("a" not in upgraders_dump_after_remove_test) 116*da0073e9SAndroid Build Coastguard Worker self.assertTrue("c" not in upgraders_dump_after_remove_test) 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker def test_aten_div_tensor_at_3(self): 119*da0073e9SAndroid Build Coastguard Worker model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_v3.pt" 120*da0073e9SAndroid Build Coastguard Worker loaded_model = torch.jit.load(model_path) 121*da0073e9SAndroid Build Coastguard Worker # there are 3 aten::div in this model 122*da0073e9SAndroid Build Coastguard Worker # And the upgrader for aten::div uses two 123*da0073e9SAndroid Build Coastguard Worker # div's because of if/else branch 124*da0073e9SAndroid Build Coastguard Worker FileCheck().check("prim::If").run(loaded_model.graph) 125*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::div", 6).run(loaded_model.graph) 126*da0073e9SAndroid Build Coastguard Worker 127*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 128*da0073e9SAndroid Build Coastguard Worker torch.jit.save(loaded_model, buffer) 129*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 130*da0073e9SAndroid Build Coastguard Worker version = self._load_model_version(loaded_model) 131*da0073e9SAndroid Build Coastguard Worker self.assertTrue(version == 4) 132*da0073e9SAndroid Build Coastguard Worker loaded_model_twice = torch.jit.load(buffer) 133*da0073e9SAndroid Build Coastguard Worker # we check by its code because graph variable names 134*da0073e9SAndroid Build Coastguard Worker # can be different every time 135*da0073e9SAndroid Build Coastguard Worker self.assertEqual(loaded_model.code, loaded_model_twice.code) 136*da0073e9SAndroid Build Coastguard Worker 137*da0073e9SAndroid Build Coastguard Worker def test_aten_full_other_variants(self): 138*da0073e9SAndroid Build Coastguard Worker def test_func(): 139*da0073e9SAndroid Build Coastguard Worker a = torch.full([4, 5, 6], 4, names=["a", "b", "c"], dtype=torch.int64) 140*da0073e9SAndroid Build Coastguard Worker return a 141*da0073e9SAndroid Build Coastguard Worker 142*da0073e9SAndroid Build Coastguard Worker scripted_func = torch.jit.script(test_func) 143*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 144*da0073e9SAndroid Build Coastguard Worker torch.jit.save(scripted_func, buffer) 145*da0073e9SAndroid Build Coastguard Worker 146*da0073e9SAndroid Build Coastguard Worker current_flag_value = torch._C._get_version_calculator_flag() 147*da0073e9SAndroid Build Coastguard Worker # calculate based on old version 148*da0073e9SAndroid Build Coastguard Worker torch._C._calculate_package_version_based_on_upgraders(False) 149*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 150*da0073e9SAndroid Build Coastguard Worker loaded_func = torch.jit.load(buffer) 151*da0073e9SAndroid Build Coastguard Worker version = self._load_model_version(loaded_func) 152*da0073e9SAndroid Build Coastguard Worker self.assertTrue(version == 5) 153*da0073e9SAndroid Build Coastguard Worker 154*da0073e9SAndroid Build Coastguard Worker # calculate based on new version 155*da0073e9SAndroid Build Coastguard Worker torch._C._calculate_package_version_based_on_upgraders(True) 156*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 157*da0073e9SAndroid Build Coastguard Worker loaded_func = torch.jit.load(buffer) 158*da0073e9SAndroid Build Coastguard Worker version = self._load_model_version(loaded_func) 159*da0073e9SAndroid Build Coastguard Worker self.assertTrue(version == 5) 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Worker # make sure we preserve old behaviou 162*da0073e9SAndroid Build Coastguard Worker torch._C._calculate_package_version_based_on_upgraders(current_flag_value) 163*da0073e9SAndroid Build Coastguard Worker 164*da0073e9SAndroid Build Coastguard Worker def test_aten_linspace(self): 165*da0073e9SAndroid Build Coastguard Worker model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_v7.ptl" 166*da0073e9SAndroid Build Coastguard Worker loaded_model = torch.jit.load(model_path) 167*da0073e9SAndroid Build Coastguard Worker sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j)) 168*da0073e9SAndroid Build Coastguard Worker for a, b in sample_inputs: 169*da0073e9SAndroid Build Coastguard Worker output_with_step, output_without_step = loaded_model(a, b) 170*da0073e9SAndroid Build Coastguard Worker # when no step is given, should have used 100 171*da0073e9SAndroid Build Coastguard Worker self.assertTrue(output_without_step.size(dim=0) == 100) 172*da0073e9SAndroid Build Coastguard Worker self.assertTrue(output_with_step.size(dim=0) == 5) 173*da0073e9SAndroid Build Coastguard Worker 174*da0073e9SAndroid Build Coastguard Worker version = self._load_model_version(loaded_model) 175*da0073e9SAndroid Build Coastguard Worker self.assertTrue(version == 8) 176*da0073e9SAndroid Build Coastguard Worker 177*da0073e9SAndroid Build Coastguard Worker def test_aten_linspace_out(self): 178*da0073e9SAndroid Build Coastguard Worker model_path = ( 179*da0073e9SAndroid Build Coastguard Worker pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_out_v7.ptl" 180*da0073e9SAndroid Build Coastguard Worker ) 181*da0073e9SAndroid Build Coastguard Worker loaded_model = torch.jit.load(model_path) 182*da0073e9SAndroid Build Coastguard Worker sample_inputs = ( 183*da0073e9SAndroid Build Coastguard Worker (3, 10, torch.empty((100,), dtype=torch.int64)), 184*da0073e9SAndroid Build Coastguard Worker (-10, 10, torch.empty((100,), dtype=torch.int64)), 185*da0073e9SAndroid Build Coastguard Worker (4.0, 6.0, torch.empty((100,), dtype=torch.float64)), 186*da0073e9SAndroid Build Coastguard Worker (3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64)), 187*da0073e9SAndroid Build Coastguard Worker ) 188*da0073e9SAndroid Build Coastguard Worker for a, b, c in sample_inputs: 189*da0073e9SAndroid Build Coastguard Worker output = loaded_model(a, b, c) 190*da0073e9SAndroid Build Coastguard Worker # when no step is given, should have used 100 191*da0073e9SAndroid Build Coastguard Worker self.assertTrue(output.size(dim=0) == 100) 192*da0073e9SAndroid Build Coastguard Worker 193*da0073e9SAndroid Build Coastguard Worker version = self._load_model_version(loaded_model) 194*da0073e9SAndroid Build Coastguard Worker self.assertTrue(version == 8) 195*da0073e9SAndroid Build Coastguard Worker 196*da0073e9SAndroid Build Coastguard Worker def test_aten_logspace(self): 197*da0073e9SAndroid Build Coastguard Worker model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_v8.ptl" 198*da0073e9SAndroid Build Coastguard Worker loaded_model = torch.jit.load(model_path) 199*da0073e9SAndroid Build Coastguard Worker sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j)) 200*da0073e9SAndroid Build Coastguard Worker for a, b in sample_inputs: 201*da0073e9SAndroid Build Coastguard Worker output_with_step, output_without_step = loaded_model(a, b) 202*da0073e9SAndroid Build Coastguard Worker # when no step is given, should have used 100 203*da0073e9SAndroid Build Coastguard Worker self.assertTrue(output_without_step.size(dim=0) == 100) 204*da0073e9SAndroid Build Coastguard Worker self.assertTrue(output_with_step.size(dim=0) == 5) 205*da0073e9SAndroid Build Coastguard Worker 206*da0073e9SAndroid Build Coastguard Worker version = self._load_model_version(loaded_model) 207*da0073e9SAndroid Build Coastguard Worker self.assertTrue(version == 9) 208*da0073e9SAndroid Build Coastguard Worker 209*da0073e9SAndroid Build Coastguard Worker def test_aten_logspace_out(self): 210*da0073e9SAndroid Build Coastguard Worker model_path = ( 211*da0073e9SAndroid Build Coastguard Worker pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_out_v8.ptl" 212*da0073e9SAndroid Build Coastguard Worker ) 213*da0073e9SAndroid Build Coastguard Worker loaded_model = torch.jit.load(model_path) 214*da0073e9SAndroid Build Coastguard Worker sample_inputs = ( 215*da0073e9SAndroid Build Coastguard Worker (3, 10, torch.empty((100,), dtype=torch.int64)), 216*da0073e9SAndroid Build Coastguard Worker (-10, 10, torch.empty((100,), dtype=torch.int64)), 217*da0073e9SAndroid Build Coastguard Worker (4.0, 6.0, torch.empty((100,), dtype=torch.float64)), 218*da0073e9SAndroid Build Coastguard Worker (3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64)), 219*da0073e9SAndroid Build Coastguard Worker ) 220*da0073e9SAndroid Build Coastguard Worker for a, b, c in sample_inputs: 221*da0073e9SAndroid Build Coastguard Worker output = loaded_model(a, b, c) 222*da0073e9SAndroid Build Coastguard Worker # when no step is given, should have used 100 223*da0073e9SAndroid Build Coastguard Worker self.assertTrue(output.size(dim=0) == 100) 224*da0073e9SAndroid Build Coastguard Worker 225*da0073e9SAndroid Build Coastguard Worker version = self._load_model_version(loaded_model) 226*da0073e9SAndroid Build Coastguard Worker self.assertTrue(version == 9) 227*da0073e9SAndroid Build Coastguard Worker 228*da0073e9SAndroid Build Coastguard Worker def test_aten_test_serialization(self): 229*da0073e9SAndroid Build Coastguard Worker model_path = ( 230*da0073e9SAndroid Build Coastguard Worker pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt" 231*da0073e9SAndroid Build Coastguard Worker ) 232*da0073e9SAndroid Build Coastguard Worker 233*da0073e9SAndroid Build Coastguard Worker # add test version entry to the version map 234*da0073e9SAndroid Build Coastguard Worker upgrader_bumped_version = 3 235*da0073e9SAndroid Build Coastguard Worker upgrader_name = "_test_serialization_subcmul_0_2" 236*da0073e9SAndroid Build Coastguard Worker upgrader_schema = "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor" 237*da0073e9SAndroid Build Coastguard Worker dummy_entry = torch._C._UpgraderEntry( 238*da0073e9SAndroid Build Coastguard Worker upgrader_bumped_version, upgrader_name, upgrader_schema 239*da0073e9SAndroid Build Coastguard Worker ) 240*da0073e9SAndroid Build Coastguard Worker 241*da0073e9SAndroid Build Coastguard Worker torch._C._test_only_add_entry_to_op_version_map( 242*da0073e9SAndroid Build Coastguard Worker "aten::_test_serialization_subcmul", dummy_entry 243*da0073e9SAndroid Build Coastguard Worker ) 244*da0073e9SAndroid Build Coastguard Worker 245*da0073e9SAndroid Build Coastguard Worker # add test upgrader in the upgraders map 246*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 247*da0073e9SAndroid Build Coastguard Worker def _test_serialization_subcmul_0_2( 248*da0073e9SAndroid Build Coastguard Worker self: torch.Tensor, other: torch.Tensor, alpha: Union[int, float] = 2 249*da0073e9SAndroid Build Coastguard Worker ) -> torch.Tensor: 250*da0073e9SAndroid Build Coastguard Worker return other - (self * alpha) 251*da0073e9SAndroid Build Coastguard Worker 252*da0073e9SAndroid Build Coastguard Worker torch._C._test_only_populate_upgraders( 253*da0073e9SAndroid Build Coastguard Worker { 254*da0073e9SAndroid Build Coastguard Worker "_test_serialization_subcmul_0_2": str( 255*da0073e9SAndroid Build Coastguard Worker _test_serialization_subcmul_0_2.graph 256*da0073e9SAndroid Build Coastguard Worker ) 257*da0073e9SAndroid Build Coastguard Worker } 258*da0073e9SAndroid Build Coastguard Worker ) 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Worker # test if the server is able to find the test upgraders and apply to IR 261*da0073e9SAndroid Build Coastguard Worker loaded_model = torch.jit.load(model_path) 262*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::mul", 2).run(loaded_model.graph) 263*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::sub", 2).run(loaded_model.graph) 264*da0073e9SAndroid Build Coastguard Worker 265*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 266*da0073e9SAndroid Build Coastguard Worker torch.jit.save(loaded_model, buffer) 267*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 268*da0073e9SAndroid Build Coastguard Worker version = self._load_model_version(loaded_model) 269*da0073e9SAndroid Build Coastguard Worker self.assertTrue(version == 3) 270*da0073e9SAndroid Build Coastguard Worker loaded_model_twice = torch.jit.load(buffer) 271*da0073e9SAndroid Build Coastguard Worker # we check by its' code because graph variable names 272*da0073e9SAndroid Build Coastguard Worker # can be different every time 273*da0073e9SAndroid Build Coastguard Worker self.assertEqual(loaded_model.code, loaded_model_twice.code) 274*da0073e9SAndroid Build Coastguard Worker torch._C._test_only_remove_entry_to_op_version_map( 275*da0073e9SAndroid Build Coastguard Worker "aten::_test_serialization_subcmul" 276*da0073e9SAndroid Build Coastguard Worker ) 277*da0073e9SAndroid Build Coastguard Worker torch._C._test_only_remove_upgraders( 278*da0073e9SAndroid Build Coastguard Worker { 279*da0073e9SAndroid Build Coastguard Worker "_test_serialization_subcmul_0_2": str( 280*da0073e9SAndroid Build Coastguard Worker _test_serialization_subcmul_0_2.graph 281*da0073e9SAndroid Build Coastguard Worker ) 282*da0073e9SAndroid Build Coastguard Worker } 283*da0073e9SAndroid Build Coastguard Worker ) 284*da0073e9SAndroid Build Coastguard Worker 285*da0073e9SAndroid Build Coastguard Worker def test_aten_div_scalar_at_3(self): 286*da0073e9SAndroid Build Coastguard Worker model_path = ( 287*da0073e9SAndroid Build Coastguard Worker pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt" 288*da0073e9SAndroid Build Coastguard Worker ) 289*da0073e9SAndroid Build Coastguard Worker loaded_model = torch.jit.load(model_path) 290*da0073e9SAndroid Build Coastguard Worker FileCheck().check("prim::If").run(loaded_model.graph) 291*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::div", 2).run(loaded_model.graph) 292*da0073e9SAndroid Build Coastguard Worker 293*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 294*da0073e9SAndroid Build Coastguard Worker torch.jit.save(loaded_model, buffer) 295*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 296*da0073e9SAndroid Build Coastguard Worker version = self._load_model_version(loaded_model) 297*da0073e9SAndroid Build Coastguard Worker self.assertEqual(version, 4) 298*da0073e9SAndroid Build Coastguard Worker loaded_model_twice = torch.jit.load(buffer) 299*da0073e9SAndroid Build Coastguard Worker 300*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 301*da0073e9SAndroid Build Coastguard Worker loaded_model(torch.Tensor([5.0, 3.0]), 2.0), 302*da0073e9SAndroid Build Coastguard Worker loaded_model_twice(torch.Tensor([5.0, 3.0]), 2.0), 303*da0073e9SAndroid Build Coastguard Worker ) 304*da0073e9SAndroid Build Coastguard Worker 305*da0073e9SAndroid Build Coastguard Worker def test_aten_div_tensor_out_at_3(self): 306*da0073e9SAndroid Build Coastguard Worker model_path = ( 307*da0073e9SAndroid Build Coastguard Worker pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt" 308*da0073e9SAndroid Build Coastguard Worker ) 309*da0073e9SAndroid Build Coastguard Worker loaded_model = torch.jit.load(model_path) 310*da0073e9SAndroid Build Coastguard Worker FileCheck().check("prim::If").run(loaded_model.graph) 311*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::div", 2).run(loaded_model.graph) 312*da0073e9SAndroid Build Coastguard Worker 313*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 314*da0073e9SAndroid Build Coastguard Worker torch.jit.save(loaded_model, buffer) 315*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 316*da0073e9SAndroid Build Coastguard Worker version = self._load_model_version(loaded_model) 317*da0073e9SAndroid Build Coastguard Worker self.assertTrue(version == 4) 318*da0073e9SAndroid Build Coastguard Worker loaded_model_twice = torch.jit.load(buffer) 319*da0073e9SAndroid Build Coastguard Worker # we check by its' code because graph variable names 320*da0073e9SAndroid Build Coastguard Worker # can be different every time 321*da0073e9SAndroid Build Coastguard Worker self.assertEqual(loaded_model.code, loaded_model_twice.code) 322*da0073e9SAndroid Build Coastguard Worker 323*da0073e9SAndroid Build Coastguard Worker def test_aten_full_at_4(self): 324*da0073e9SAndroid Build Coastguard Worker model_path = ( 325*da0073e9SAndroid Build Coastguard Worker pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt" 326*da0073e9SAndroid Build Coastguard Worker ) 327*da0073e9SAndroid Build Coastguard Worker loaded_model = torch.jit.load(model_path) 328*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::Float", 1).run(loaded_model.graph) 329*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::full", 2).run(loaded_model.graph) 330*da0073e9SAndroid Build Coastguard Worker 331*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 332*da0073e9SAndroid Build Coastguard Worker torch.jit.save(loaded_model, buffer) 333*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 334*da0073e9SAndroid Build Coastguard Worker version = self._load_model_version(loaded_model) 335*da0073e9SAndroid Build Coastguard Worker self.assertTrue(version == 5) 336*da0073e9SAndroid Build Coastguard Worker loaded_model_twice = torch.jit.load(buffer) 337*da0073e9SAndroid Build Coastguard Worker # we check by its' code because graph variable names 338*da0073e9SAndroid Build Coastguard Worker # can be different every time 339*da0073e9SAndroid Build Coastguard Worker self.assertEqual(loaded_model.code, loaded_model_twice.code) 340*da0073e9SAndroid Build Coastguard Worker 341*da0073e9SAndroid Build Coastguard Worker def test_aten_full_out_at_4(self): 342*da0073e9SAndroid Build Coastguard Worker model_path = ( 343*da0073e9SAndroid Build Coastguard Worker pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt" 344*da0073e9SAndroid Build Coastguard Worker ) 345*da0073e9SAndroid Build Coastguard Worker loaded_model = torch.jit.load(model_path) 346*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::full", 5).run(loaded_model.graph) 347*da0073e9SAndroid Build Coastguard Worker version = self._load_model_version(loaded_model) 348*da0073e9SAndroid Build Coastguard Worker self.assertTrue(version == 5) 349