# Owner(s): ["oncall: jit"]

import io
import os
import sys
from itertools import product as product
from typing import Union

import hypothesis.strategies as st
from hypothesis import example, given, settings

import torch


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.jit.mobile import _load_for_lite_interpreter
from torch.testing._internal.jit_utils import JitTestCase


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestSaveLoadForOpVersion(JitTestCase):
    """Tests that models serialized with older operator versions still behave
    as expected after loading.

    The ``div`` tests load ``*_v2.ptl`` lite-interpreter fixtures and compare
    them against reference implementations of the historic semantics; freshly
    scripted modules are compared against the current ``torch.div`` behavior.
    The ``linspace``/``logspace`` tests compare old fixtures against freshly
    scripted modules.
    """

    # Helper that returns the module after saving and loading
    def _save_load_module(self, m):
        """Script ``m()``, round-trip it through ``torch.jit.save``/``load``.

        Args:
            m: an ``nn.Module`` subclass; it is instantiated with no arguments.

        Returns:
            The reloaded TorchScript module.
        """
        scripted_module = torch.jit.script(m())
        buffer = io.BytesIO()
        torch.jit.save(scripted_module, buffer)
        buffer.seek(0)
        return torch.jit.load(buffer)

    def _save_load_mobile_module(self, m):
        """Script ``m()``, serialize it for the lite interpreter and reload it.

        Args:
            m: an ``nn.Module`` subclass; it is instantiated with no arguments.

        Returns:
            The module as loaded by ``_load_for_lite_interpreter``.
        """
        scripted_module = torch.jit.script(m())
        buffer = io.BytesIO(scripted_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        return _load_for_lite_interpreter(buffer)

    def _load_upgrader_model(self, filename):
        """Load a lite-interpreter fixture from ``test/cpp/jit/upgrader_models``.

        Every ``*_v2.ptl`` div fixture lives in that one directory; building the
        path here keeps the tests from pointing at the wrong location.

        Args:
            filename: the fixture's file name, e.g. ``"test_versioned_div_tensor_v2.ptl"``.

        Returns:
            The loaded mobile module.

        Raises:
            unittest.SkipTest: (via ``skipTest``) if loading fails for any reason.
        """
        path = os.path.join(
            pytorch_test_dir, "cpp", "jit", "upgrader_models", filename
        )
        try:
            return _load_for_lite_interpreter(path)
        except Exception:
            self.skipTest("Failed to load fixture!")

    def _load_jit_fixture_as_mobile(self, filename):
        """Load a TorchScript fixture from ``test/jit/fixtures`` as a mobile module.

        The fixture is loaded with ``torch.jit.load``, re-serialized for the
        lite interpreter and then reloaded with ``_load_for_lite_interpreter``.
        """
        scripted_module = torch.jit.load(
            os.path.join(pytorch_test_dir, "jit", "fixtures", filename)
        )
        buffer = io.BytesIO(scripted_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        return _load_for_lite_interpreter(buffer)

    # Helper which returns the result of a function or the exception the
    # function threw.
    def _try_fn(self, fn, *args, **kwargs):
        """Call ``fn(*args, **kwargs)``; return its result, or the raised exception."""
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            return e

    def _assert_same_outcome(self, m_result, fn_result):
        """Assert that two ``_try_fn`` outcomes agree.

        If ``m_result`` is an exception, ``fn_result`` must be one too.
        Otherwise the two results must compare equal. A non-exception
        ``m_result`` therefore fails against an exception ``fn_result``.
        """
        if isinstance(m_result, Exception):
            self.assertIsInstance(fn_result, Exception)
        else:
            self.assertEqual(m_result, fn_result)

    def _verify_no(self, kind, m):
        """Assert that ``kind`` does not appear in any node of ``m.graph``."""
        self._verify_count(kind, m, 0)

    def _verify_count(self, kind, m, count):
        """Assert ``kind`` occurs ``count`` times across the string forms of ``m.graph``'s nodes."""
        node_count = sum(str(n).count(kind) for n in m.graph.nodes())
        self.assertEqual(node_count, count)

    # Tests that verify TorchScript remaps aten::div(_) from versions 0-3
    # to call either aten::true_divide(_), if an input is a float type,
    # or truncated aten::divide(_) otherwise.
    # NOTE: currently compares against current div behavior, too, since
    # div behavior has not yet been updated.

    @settings(
        max_examples=10, deadline=200000
    )  # A total of 10 examples will be generated
    @given(
        sample_input=st.tuples(
            st.integers(min_value=5, max_value=199),
            st.floats(min_value=5.0, max_value=199.0),
        )
    )  # Generate a pair (integer, float)
    @example((2, 3, 2.0, 3.0))  # Ensure this example will be covered
    def test_versioned_div_tensor(self, sample_input):
        def historic_div(self, other):
            if self.is_floating_point() or other.is_floating_point():
                return self.true_divide(other)
            return self.divide(other, rounding_mode="trunc")

        # Tensor x Tensor
        class MyModule(torch.nn.Module):
            def forward(self, a, b):
                result_0 = a / b
                result_1 = torch.div(a, b)
                result_2 = a.div(b)

                return result_0, result_1, result_2

        # Loads historic module
        v3_mobile_module = self._load_upgrader_model("test_versioned_div_tensor_v2.ptl")
        current_mobile_module = self._save_load_mobile_module(MyModule)

        def _helper(m, fn, a, b):
            # The module returns three results (a / b, torch.div, Tensor.div);
            # each must match the single reference result.
            m_results = self._try_fn(m, a, b)
            fn_result = self._try_fn(fn, a, b)

            if isinstance(m_results, Exception):
                self.assertIsInstance(fn_result, Exception)
            else:
                for result in m_results:
                    self.assertEqual(result, fn_result)

        for val_a, val_b in product(sample_input, sample_input):
            a = torch.tensor((val_a,))
            b = torch.tensor((val_b,))

            _helper(v3_mobile_module, historic_div, a, b)
            _helper(current_mobile_module, torch.div, a, b)

    @settings(
        max_examples=10, deadline=200000
    )  # A total of 10 examples will be generated
    @given(
        sample_input=st.tuples(
            st.integers(min_value=5, max_value=199),
            st.floats(min_value=5.0, max_value=199.0),
        )
    )  # Generate a pair (integer, float)
    @example((2, 3, 2.0, 3.0))  # Ensure this example will be covered
    def test_versioned_div_tensor_inplace(self, sample_input):
        def historic_div_(self, other):
            if self.is_floating_point() or other.is_floating_point():
                return self.true_divide_(other)
            return self.divide_(other, rounding_mode="trunc")

        class MyModule(torch.nn.Module):
            def forward(self, a, b):
                a /= b
                return a

        v3_mobile_module = self._load_upgrader_model(
            "test_versioned_div_tensor_inplace_v2.ptl"
        )
        current_mobile_module = self._save_load_mobile_module(MyModule)

        def _helper(m, fn, a, b):
            # The reference runs on a copy so that only the module mutates ``a``.
            fn_result = self._try_fn(fn, a.clone(), b)
            m_result = self._try_fn(m, a, b)
            self._assert_same_outcome(m_result, fn_result)
            if not isinstance(m_result, Exception):
                # The module must have written its result into ``a``.
                self.assertEqual(m_result, a)

        for val_a, val_b in product(sample_input, sample_input):
            b = torch.tensor((val_b,))

            # Each call gets a fresh ``a`` since the module modifies it in place
            _helper(v3_mobile_module, historic_div_, torch.tensor((val_a,)), b)
            _helper(
                current_mobile_module, torch.Tensor.div_, torch.tensor((val_a,)), b
            )

    @settings(
        max_examples=10, deadline=200000
    )  # A total of 10 examples will be generated
    @given(
        sample_input=st.tuples(
            st.integers(min_value=5, max_value=199),
            st.floats(min_value=5.0, max_value=199.0),
        )
    )  # Generate a pair (integer, float)
    @example((2, 3, 2.0, 3.0))  # Ensure this example will be covered
    def test_versioned_div_tensor_out(self, sample_input):
        def historic_div_out(self, other, out):
            if (
                self.is_floating_point()
                or other.is_floating_point()
                or out.is_floating_point()
            ):
                return torch.true_divide(self, other, out=out)
            return torch.divide(self, other, out=out, rounding_mode="trunc")

        class MyModule(torch.nn.Module):
            def forward(self, a, b, out):
                return a.div(b, out=out)

        v3_mobile_module = self._load_upgrader_model(
            "test_versioned_div_tensor_out_v2.ptl"
        )
        current_mobile_module = self._save_load_mobile_module(MyModule)

        def _helper(m, fn, a, b, out):
            # The reference writes into a copy of ``out`` so that only the
            # module writes into ``out`` itself.
            if fn is torch.div:
                fn_result = self._try_fn(fn, a, b, out=out.clone())
            else:
                fn_result = self._try_fn(fn, a, b, out.clone())
            m_result = self._try_fn(m, a, b, out)

            self._assert_same_outcome(m_result, fn_result)
            if not isinstance(m_result, Exception):
                self.assertEqual(m_result, out)

        for val_a, val_b in product(sample_input, sample_input):
            a = torch.tensor((val_a,))
            b = torch.tensor((val_b,))

            # Exercise both a float and an integral ``out`` tensor.
            for out in (torch.empty((1,)), torch.empty((1,), dtype=torch.long)):
                _helper(v3_mobile_module, historic_div_out, a, b, out)
                _helper(current_mobile_module, torch.div, a, b, out)

    @settings(
        max_examples=10, deadline=200000
    )  # A total of 10 examples will be generated
    @given(
        sample_input=st.tuples(
            st.integers(min_value=5, max_value=199),
            st.floats(min_value=5.0, max_value=199.0),
        )
    )  # Generate a pair (integer, float)
    @example((2, 3, 2.0, 3.0))  # Ensure this example will be covered
    def test_versioned_div_scalar(self, sample_input):
        def historic_div_scalar_float(self, other: float):
            return torch.true_divide(self, other)

        def historic_div_scalar_int(self, other: int):
            if self.is_floating_point():
                return torch.true_divide(self, other)
            return torch.divide(self, other, rounding_mode="trunc")

        class MyModuleFloat(torch.nn.Module):
            def forward(self, a, b: float):
                return a / b

        class MyModuleInt(torch.nn.Module):
            def forward(self, a, b: int):
                return a / b

        v3_mobile_module_float = self._load_upgrader_model(
            "test_versioned_div_scalar_float_v2.ptl"
        )
        v3_mobile_module_int = self._load_upgrader_model(
            "test_versioned_div_scalar_int_v2.ptl"
        )

        current_mobile_module_float = self._save_load_mobile_module(MyModuleFloat)
        current_mobile_module_int = self._save_load_mobile_module(MyModuleInt)

        def _helper(m, fn, a, b):
            m_result = self._try_fn(m, a, b)
            fn_result = self._try_fn(fn, a, b)
            self._assert_same_outcome(m_result, fn_result)

        for val_a, val_b in product(sample_input, sample_input):
            a = torch.tensor((val_a,))
            b = val_b

            if isinstance(b, float):
                _helper(v3_mobile_module_float, historic_div_scalar_float, a, b)
                _helper(current_mobile_module_float, torch.div, a, b)
            else:
                _helper(v3_mobile_module_int, historic_div_scalar_int, a, b)
                _helper(current_mobile_module_int, torch.div, a, b)

    @settings(
        max_examples=10, deadline=200000
    )  # A total of 10 examples will be generated
    @given(
        sample_input=st.tuples(
            st.integers(min_value=5, max_value=199),
            st.floats(min_value=5.0, max_value=199.0),
        )
    )  # Generate a pair (integer, float)
    @example((2, 3, 2.0, 3.0))  # Ensure this example will be covered
    def test_versioned_div_scalar_reciprocal(self, sample_input):
        # NOTE(review): these reference functions document the historic
        # semantics but are not called below; the fixtures are compared
        # against the current modules instead.
        def historic_div_scalar_float_reciprocal(self, other: float):
            return other / self

        def historic_div_scalar_int_reciprocal(self, other: int):
            if self.is_floating_point():
                return other / self
            return torch.divide(other, self, rounding_mode="trunc")

        class MyModuleFloat(torch.nn.Module):
            def forward(self, a, b: float):
                return b / a

        class MyModuleInt(torch.nn.Module):
            def forward(self, a, b: int):
                return b / a

        v3_mobile_module_float = self._load_upgrader_model(
            "test_versioned_div_scalar_reciprocal_float_v2.ptl"
        )
        v3_mobile_module_int = self._load_upgrader_model(
            "test_versioned_div_scalar_reciprocal_int_v2.ptl"
        )

        current_mobile_module_float = self._save_load_mobile_module(MyModuleFloat)
        current_mobile_module_int = self._save_load_mobile_module(MyModuleInt)

        def _helper(m, fn, a, b):
            m_result = self._try_fn(m, a, b)
            # Reverses argument order for torch.div
            if fn is torch.div:
                fn_result = self._try_fn(torch.div, b, a)
            else:
                fn_result = self._try_fn(fn, a, b)

            if isinstance(m_result, Exception):
                self.assertIsInstance(fn_result, Exception)
            elif fn is torch.div or a.is_floating_point():
                self.assertEqual(m_result, fn_result)
            # Otherwise (fn is a module and ``a`` is integral) the comparison
            # is skipped: per historic_div_scalar_int_reciprocal, the historic
            # semantics truncate for integral inputs, unlike current true
            # division.

        for val_a, val_b in product(sample_input, sample_input):
            a = torch.tensor((val_a,))
            b = val_b

            if isinstance(b, float):
                _helper(v3_mobile_module_float, current_mobile_module_float, a, b)
                _helper(current_mobile_module_float, torch.div, a, b)
            else:
                _helper(v3_mobile_module_int, current_mobile_module_int, a, b)
                _helper(current_mobile_module_int, torch.div, a, b)

    @settings(
        max_examples=10, deadline=200000
    )  # A total of 10 examples will be generated
    @given(
        sample_input=st.tuples(
            st.integers(min_value=5, max_value=199),
            st.floats(min_value=5.0, max_value=199.0),
        )
    )  # Generate a pair (integer, float)
    @example((2, 3, 2.0, 3.0))  # Ensure this example will be covered
    def test_versioned_div_scalar_inplace(self, sample_input):
        def historic_div_scalar_float_inplace(self, other: float):
            return self.true_divide_(other)

        def historic_div_scalar_int_inplace(self, other: int):
            if self.is_floating_point():
                return self.true_divide_(other)

            return self.divide_(other, rounding_mode="trunc")

        class MyModuleFloat(torch.nn.Module):
            def forward(self, a, b: float):
                a /= b
                return a

        class MyModuleInt(torch.nn.Module):
            def forward(self, a, b: int):
                a /= b
                return a

        v3_mobile_module_float = self._load_upgrader_model(
            "test_versioned_div_scalar_inplace_float_v2.ptl"
        )
        v3_mobile_module_int = self._load_upgrader_model(
            "test_versioned_div_scalar_inplace_int_v2.ptl"
        )

        current_mobile_module_float = self._save_load_mobile_module(MyModuleFloat)
        current_mobile_module_int = self._save_load_mobile_module(MyModuleInt)

        def _helper(m, fn, a, b):
            # The reference divides a copy so that only the module mutates
            # ``a``. Without the copy, both results alias the same tensor
            # and always compare equal.
            fn_result = self._try_fn(fn, a.clone(), b)
            m_result = self._try_fn(m, a, b)
            self._assert_same_outcome(m_result, fn_result)
            if not isinstance(m_result, Exception):
                # The module must have written its result into ``a``.
                self.assertEqual(m_result, a)

        for val_a, val_b in product(sample_input, sample_input):
            b = val_b

            # Each call gets a fresh ``a`` since the modules modify it in place
            if isinstance(b, float):
                _helper(
                    v3_mobile_module_float,
                    historic_div_scalar_float_inplace,
                    torch.tensor((val_a,)),
                    b,
                )
                _helper(
                    current_mobile_module_float,
                    torch.Tensor.div_,
                    torch.tensor((val_a,)),
                    b,
                )
            else:
                _helper(
                    v3_mobile_module_int,
                    historic_div_scalar_int_inplace,
                    torch.tensor((val_a,)),
                    b,
                )
                _helper(
                    current_mobile_module_int,
                    torch.Tensor.div_,
                    torch.tensor((val_a,)),
                    b,
                )

    # NOTE: Scalar division was already true division in op version 3,
    # so this test verifies the behavior is unchanged.
    def test_versioned_div_scalar_scalar(self):
        class MyModule(torch.nn.Module):
            def forward(self, a: float, b: int, c: float, d: int):
                result_0 = a / b
                result_1 = a / c
                result_2 = b / c
                result_3 = b / d
                return (result_0, result_1, result_2, result_3)

        v3_mobile_module = self._load_upgrader_model(
            "test_versioned_div_scalar_scalar_v2.ptl"
        )
        current_mobile_module = self._save_load_mobile_module(MyModule)

        def _helper(m, fn):
            vals = (5.0, 3, 2.0, 7)
            m_result = m(*vals)
            fn_result = fn(*vals)
            for mr, hr in zip(m_result, fn_result):
                self.assertEqual(mr, hr)

        _helper(v3_mobile_module, current_mobile_module)

    def test_versioned_linspace(self):
        class Module(torch.nn.Module):
            def forward(
                self, a: Union[int, float, complex], b: Union[int, float, complex]
            ):
                c = torch.linspace(a, b, steps=5)
                d = torch.linspace(a, b, steps=100)
                return c, d

        v7_mobile_module = self._load_jit_fixture_as_mobile(
            "test_versioned_linspace_v7.ptl"
        )
        current_mobile_module = self._save_load_mobile_module(Module)

        sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
        for a, b in sample_inputs:
            (output_with_step, output_without_step) = v7_mobile_module(a, b)
            (current_with_step, current_without_step) = current_mobile_module(a, b)
            # when no step is given, should have used 100
            self.assertEqual(output_without_step.size(dim=0), 100)
            self.assertEqual(output_with_step.size(dim=0), 5)
            # outputs should be equal to the newest version
            self.assertEqual(output_with_step, current_with_step)
            self.assertEqual(output_without_step, current_without_step)

    def test_versioned_linspace_out(self):
        class Module(torch.nn.Module):
            def forward(
                self,
                a: Union[int, float, complex],
                b: Union[int, float, complex],
                out: torch.Tensor,
            ):
                return torch.linspace(a, b, steps=100, out=out)

        v7_mobile_module = self._load_jit_fixture_as_mobile(
            "test_versioned_linspace_out_v7.ptl"
        )
        current_mobile_module = self._save_load_mobile_module(Module)

        # (start, end, out tensor for the old model, out tensor for the new one)
        sample_inputs = (
            (
                3,
                10,
                torch.empty((100,), dtype=torch.int64),
                torch.empty((100,), dtype=torch.int64),
            ),
            (
                -10,
                10,
                torch.empty((100,), dtype=torch.int64),
                torch.empty((100,), dtype=torch.int64),
            ),
            (
                4.0,
                6.0,
                torch.empty((100,), dtype=torch.float64),
                torch.empty((100,), dtype=torch.float64),
            ),
            (
                3 + 4j,
                4 + 5j,
                torch.empty((100,), dtype=torch.complex64),
                torch.empty((100,), dtype=torch.complex64),
            ),
        )
        for start, end, out_for_old, out_for_new in sample_inputs:
            output = v7_mobile_module(start, end, out_for_old)
            output_current = current_mobile_module(start, end, out_for_new)
            # when no step is given, should have used 100
            self.assertEqual(output.size(dim=0), 100)
            # "Upgraded" model should match the new version output
            self.assertEqual(output, output_current)

    def test_versioned_logspace(self):
        class Module(torch.nn.Module):
            def forward(
                self, a: Union[int, float, complex], b: Union[int, float, complex]
            ):
                c = torch.logspace(a, b, steps=5)
                d = torch.logspace(a, b, steps=100)
                return c, d

        v8_mobile_module = self._load_jit_fixture_as_mobile(
            "test_versioned_logspace_v8.ptl"
        )
        current_mobile_module = self._save_load_mobile_module(Module)

        sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
        for a, b in sample_inputs:
            (output_with_step, output_without_step) = v8_mobile_module(a, b)
            (current_with_step, current_without_step) = current_mobile_module(a, b)
            # when no step is given, should have used 100
            self.assertEqual(output_without_step.size(dim=0), 100)
            self.assertEqual(output_with_step.size(dim=0), 5)
            # outputs should be equal to the newest version
            self.assertEqual(output_with_step, current_with_step)
            self.assertEqual(output_without_step, current_without_step)

    def test_versioned_logspace_out(self):
        class Module(torch.nn.Module):
            def forward(
                self,
                a: Union[int, float, complex],
                b: Union[int, float, complex],
                out: torch.Tensor,
            ):
                return torch.logspace(a, b, steps=100, out=out)

        v8_mobile_module = self._load_jit_fixture_as_mobile(
            "test_versioned_logspace_out_v8.ptl"
        )
        current_mobile_module = self._save_load_mobile_module(Module)

        # (start, end, out tensor for the old model, out tensor for the new one)
        sample_inputs = (
            (
                3,
                10,
                torch.empty((100,), dtype=torch.int64),
                torch.empty((100,), dtype=torch.int64),
            ),
            (
                -10,
                10,
                torch.empty((100,), dtype=torch.int64),
                torch.empty((100,), dtype=torch.int64),
            ),
            (
                4.0,
                6.0,
                torch.empty((100,), dtype=torch.float64),
                torch.empty((100,), dtype=torch.float64),
            ),
            (
                3 + 4j,
                4 + 5j,
                torch.empty((100,), dtype=torch.complex64),
                torch.empty((100,), dtype=torch.complex64),
            ),
        )
        for start, end, out_for_old, out_for_new in sample_inputs:
            output = v8_mobile_module(start, end, out_for_old)
            output_current = current_mobile_module(start, end, out_for_new)
            # when no step is given, should have used 100
            self.assertEqual(output.size(dim=0), 100)
            # "Upgraded" model should match the new version output
            self.assertEqual(output, output_current)