1# Owner(s): ["oncall: export"] 2 3import unittest 4from collections import OrderedDict 5from typing import Any, Dict, List, Optional, Tuple 6 7import torch 8import torch.utils._pytree as pytree 9from torch._dynamo.test_case import TestCase 10from torch._export.converter import TS2EPConverter 11from torch.export import ExportedProgram 12from torch.testing._internal.common_quantized import override_quantized_engine 13from torch.testing._internal.common_utils import IS_WINDOWS, run_tests 14from torch.testing._internal.torchbind_impls import ( 15 _empty_tensor_queue, 16 init_torchbind_implementations, 17) 18 19 20requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda") 21 22 23class TestConverter(TestCase): 24 def setUp(self): 25 init_torchbind_implementations() 26 27 @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue") 28 class FakeTensorQueue: 29 def __init__(self, queue): 30 self.queue = queue 31 32 @classmethod 33 def __obj_unflatten__(cls, flattened_ctx): 34 return cls(**dict(flattened_ctx)) 35 36 def push(self, x): 37 self.queue.append(x) 38 39 def pop(self): 40 if self.is_empty(): 41 return torch.empty([]) 42 return self.queue.pop(0) 43 44 def size(self): 45 return len(self.queue) 46 47 def is_empty(self): 48 return len(self.queue) == 0 49 50 def float_size(self): 51 return float(len(self.queue)) 52 53 self.torch_bind_ops = [ 54 torch.ops._TorchScriptTesting.queue_pop, 55 torch.ops._TorchScriptTesting.queue_push, 56 torch.ops._TorchScriptTesting.queue_size, 57 ] 58 59 def tearDown(self): 60 torch._library.fake_class_registry.deregister_fake_class( 61 "_TorchScriptTesting::_TensorQueue" 62 ) 63 64 def _check_equal_ts_ep_converter( 65 self, 66 M, 67 inp, 68 option: Optional[List[str]] = None, 69 check_persistent=False, 70 lifted_tensor_constants=None, 71 ) -> List[ExportedProgram]: 72 # By default, it tests both jit.trace and jit.script. 73 if option is None: 74 option = ["trace", "script"] 75 76 if check_persistent: 77 num_iterations = 10 78 else: 79 num_iterations = 1 80 81 ep_list = [] 82 for opt in option: 83 if opt == "script": 84 # Separate two models for testing non-functional effects 85 if check_persistent: 86 original_ts_model = torch.jit.script(M()) 87 ts_model = torch.jit.script(M()) 88 eager_model = M() 89 else: 90 original_ts_model = torch.jit.script(M) 91 ts_model = torch.jit.script(M) 92 eager_model = M 93 elif opt == "trace": 94 if check_persistent: 95 original_ts_model = torch.jit.trace(M(), inp) 96 ts_model = torch.jit.trace(M(), inp) 97 eager_model = M() 98 else: 99 original_ts_model = torch.jit.trace(M, inp) 100 ts_model = torch.jit.trace(M, inp) 101 eager_model = M 102 else: 103 raise RuntimeError(f"Unrecognized mode for torch.jit: {opt}") 104 105 converter = TS2EPConverter(ts_model, inp) 106 ep = converter.convert() 107 ep_list.append(ep) 108 109 for _ in range(num_iterations): 110 orig_out, _ = pytree.tree_flatten(original_ts_model(*inp)) 111 ep_out, _ = pytree.tree_flatten(ep.module()(*inp)) 112 113 # Check module. 114 if isinstance(eager_model, torch.nn.Module): 115 expected_state_dict = OrderedDict() 116 expected_state_dict.update(ts_model.state_dict()) 117 if lifted_tensor_constants: 118 expected_state_dict.update(lifted_tensor_constants) 119 self.assertEqual( 120 ep.state_dict.keys(), 121 expected_state_dict.keys(), 122 ) 123 124 # Check results 125 self._check_tensor_list_equal(ep_out, orig_out) 126 return ep_list 127 128 def _check_tensor_list_equal(self, xs: List[torch.Tensor], ys: List[torch.Tensor]): 129 self.assertEqual(len(xs), len(ys)) 130 for x, y in zip(xs, ys): 131 if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor): 132 self.assertEqual(x.shape, y.shape) 133 self.assertTrue(torch.allclose(x, y)) 134 else: 135 self.assertEqual(type(x), type(y)) 136 self.assertEqual(x, y) 137 138 def test_ts2ep_converter_basic(self): 139 class MSingle(torch.nn.Module): 140 def forward(self, x, y): 141 return x + y 142 143 class MMulti(torch.nn.Module): 144 def forward(self, x, y): 145 x = x.cos() + 1 146 y = y.sin() - 1 147 return x, y 148 149 inp = (torch.ones(1, 3), torch.ones(1, 3)) 150 self._check_equal_ts_ep_converter(MSingle(), inp) 151 self._check_equal_ts_ep_converter(MMulti(), inp) 152 153 def test_ts2ep_converter_container_output(self): 154 # Output is a List. 155 class MOutputList(torch.nn.Module): 156 def forward(self, x: torch.Tensor, y: torch.Tensor): 157 a = x * x 158 b = y + y 159 return [a, b] 160 161 # Output is a Tuple. 162 class MOutputTuple(torch.nn.Module): 163 def forward(self, x: torch.Tensor, y: torch.Tensor): 164 a = x * x 165 b = y + y 166 return (a, b) 167 168 # Output is a Dict. 169 class MOutputDict(torch.nn.Module): 170 def forward(self, x: torch.Tensor, y: torch.Tensor): 171 a = x * x 172 b = y + y 173 return {"data": {"mul": a, "add": b}} 174 175 inp = (torch.tensor(4), torch.tensor(4)) 176 177 # Traced function must use immutable structure as output. 178 self._check_equal_ts_ep_converter(MOutputList(), inp, ["script"]) 179 self._check_equal_ts_ep_converter(MOutputTuple(), inp) 180 self._check_equal_ts_ep_converter(MOutputDict(), inp, ["script"]) 181 182 def test_aten_dim(self): 183 class Module(torch.nn.Module): 184 def forward(self, x): 185 num_dim = x.dim() 186 return torch.ones(num_dim) 187 188 inp = (torch.ones(1, 3),) 189 self._check_equal_ts_ep_converter(Module(), inp) 190 191 def test_aten_len(self): 192 class Module(torch.nn.Module): 193 def forward(self, x: torch.Tensor): 194 length = len(x) 195 return torch.ones(length) 196 197 # aten::len.Tensor 198 inp = (torch.ones(2, 3),) 199 self._check_equal_ts_ep_converter(Module(), inp) 200 201 class Module(torch.nn.Module): 202 def forward(self, x: List[int]): 203 length = len(x) 204 return torch.ones(length) 205 206 # aten::len.t 207 inp = ([1, 2, 3],) 208 self._check_equal_ts_ep_converter(Module(), inp, ["script"]) 209 210 class Module(torch.nn.Module): 211 def forward(self, x: Dict[int, str]): 212 length = len(x) 213 return torch.ones(length) 214 215 # aten::len.Dict_int 216 inp = ({1: "a", 2: "b", 3: "c"},) 217 self._check_equal_ts_ep_converter(Module(), inp, ["script"]) 218 219 class Module(torch.nn.Module): 220 def forward(self, x: Dict[bool, str]): 221 length = len(x) 222 return torch.ones(length) 223 224 # aten::len.Dict_bool 225 inp = ({True: "a", False: "b"},) 226 self._check_equal_ts_ep_converter(Module(), inp, ["script"]) 227 228 class Module(torch.nn.Module): 229 def forward(self, x: Dict[float, str]): 230 length = len(x) 231 return torch.ones(length) 232 233 # aten::len.Dict_float 234 inp = ({1.2: "a", 3.4: "b"},) 235 self._check_equal_ts_ep_converter(Module(), inp, ["script"]) 236 237 class Module(torch.nn.Module): 238 def forward(self, x: Dict[torch.Tensor, str]): 239 length = len(x) 240 return torch.ones(length) 241 242 # aten::len.Dict_Tensor 243 inp = ({torch.zeros(2, 3): "a", torch.ones(2, 3): "b"},) 244 self._check_equal_ts_ep_converter(Module(), inp, ["script"]) 245 246 # aten::len.str and aten::len.Dict_str are not supported 247 # since torch._C._jit_flatten does not support str 248 # inp = ("abcdefg",) 249 # self._check_equal_ts_ep_converter(Module(), inp) 250 # inp = ({"a": 1, "b": 2},) 251 # self._check_equal_ts_ep_converter(Module(), inp) 252 253 def test_aten_add_t(self): 254 # python list append 255 class Module(torch.nn.Module): 256 def forward(self, x: List[torch.Tensor]): 257 out = [] 258 out = out + x 259 a = torch.cat(out) 260 out = out + x 261 b = torch.cat(out) 262 return a, b 263 264 inp = ([torch.ones(2, 3), torch.ones(2, 3)],) 265 self._check_equal_ts_ep_converter(Module(), inp, ["script"]) 266 267 def test_aten_to_dtype_with_mutating_storage(self): 268 class Module(torch.nn.Module): 269 def forward(self, x: torch.Tensor, y: torch.Tensor): 270 x = x.to(y.dtype) 271 torch.ops.aten.index_put_(x, [torch.tensor([0])], y) 272 return x 273 274 inp = (torch.ones(2, 3), torch.tensor([0, 0, 0])) 275 self._check_equal_ts_ep_converter(Module(), inp) 276 277 def test_prim_min(self): 278 class Module(torch.nn.Module): 279 def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 280 x_len = len(x) 281 y_len = len(y) 282 283 # prim::min.int 284 len_int = min(x_len, y_len) 285 286 # prim::min.float 287 len_float = int(min(x_len * 2.0, y_len * 2.0)) 288 289 # prim::min.self_int 290 len_self_int = min([x_len, y_len]) 291 292 # prim::min.self_float 293 len_self_float = int(min([x_len * 2.0, y_len * 2.0])) 294 295 # prim::min.float_int 296 len_float_int = int(min(x_len * 2.0, y_len)) 297 298 # prim::min.int_float 299 len_int_float = int(min(x_len, y_len * 2.0)) 300 301 return torch.ones( 302 len_int 303 + len_float 304 + len_self_int 305 + len_self_float 306 + len_float_int 307 + len_int_float 308 ) 309 310 inp = (torch.randn(10, 2), torch.randn(5)) 311 self._check_equal_ts_ep_converter(Module(), inp) 312 313 def test_prim_max(self): 314 class Module(torch.nn.Module): 315 def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 316 x_len = len(x) 317 y_len = len(y) 318 319 # prim::max.int 320 len_int = max(x_len, y_len) 321 322 # prim::max.float 323 len_float = int(max(x_len * 2.0, y_len * 2.0)) 324 325 # prim::max.self_int 326 len_self_int = max([x_len, y_len]) 327 328 # prim::max.self_float 329 len_self_float = int(max([x_len * 2.0, y_len * 2.0])) 330 331 # prim::max.float_int 332 len_float_int = int(max(x_len * 2.0, y_len)) 333 334 # prim::max.int_float 335 len_int_float = int(max(x_len, y_len * 2.0)) 336 337 return torch.ones( 338 len_int 339 + len_float 340 + len_self_int 341 + len_self_float 342 + len_float_int 343 + len_int_float 344 ) 345 346 inp = (torch.randn(10, 2), torch.randn(5)) 347 self._check_equal_ts_ep_converter(Module(), inp) 348 349 def test_aten___getitem___list(self): 350 class Module(torch.nn.Module): 351 def forward(self, x): 352 y = torch.split(x, 2) 353 return y[0] 354 355 inp = (torch.rand((3, 2)),) 356 self._check_equal_ts_ep_converter(Module(), inp) 357 358 def test_aten___getitem___dict(self): 359 class Module(torch.nn.Module): 360 def forward(self, x): 361 y = torch.split(x, 2) 362 d_int = {0: y[0], 1: y[1]} 363 d_str = {"0": y[0], "1": y[1]} 364 d_bool = {True: y[0], False: y[1]} 365 d_float = {0.1: y[0], 2.3: y[1]} 366 return d_int[0], d_str["0"], d_bool[True], d_float[0.1] 367 368 inp = (torch.rand((3, 2)),) 369 self._check_equal_ts_ep_converter(Module(), inp) 370 371 def test_prim_device(self): 372 class Module(torch.nn.Module): 373 def forward(self, x): 374 device = x.device 375 return torch.ones(2, 3, device=device) 376 377 inp = (torch.rand(3, 4),) 378 self._check_equal_ts_ep_converter(Module(), inp) 379 380 @requires_cuda 381 def test_prim_device_cuda(self): 382 class Module(torch.nn.Module): 383 def forward(self, x): 384 device = x.device 385 return torch.ones(2, 3, device=device) 386 387 inp = (torch.rand((3, 4), device="cuda:0"),) 388 self._check_equal_ts_ep_converter(Module(), inp) 389 390 def test_prim_dtype(self): 391 class Module(torch.nn.Module): 392 def forward(self, x): 393 dtype = x.dtype 394 return torch.ones(2, 3, dtype=dtype) 395 396 for dtype in [ 397 torch.float32, 398 torch.double, 399 ]: 400 inp = (torch.rand((3, 4), dtype=dtype),) 401 self._check_equal_ts_ep_converter(Module(), inp) 402 403 for dtype in [ 404 torch.uint8, 405 torch.int8, 406 torch.int32, 407 ]: 408 inp = (torch.randint(high=128, size=(3, 4), dtype=dtype),) 409 self._check_equal_ts_ep_converter(Module(), inp) 410 411 def test_convert_if_basic(self): 412 class M(torch.nn.Module): 413 def forward(self, x: torch.Tensor, y: torch.Tensor): 414 if x: 415 return y * y 416 else: 417 return y + y 418 419 inp = (torch.tensor(True), torch.tensor(4)) 420 ep_list = self._check_equal_ts_ep_converter(M(), inp) 421 422 for ep in ep_list[1:]: 423 torch.testing.assert_close( 424 ep.module()(torch.tensor(False), torch.tensor(4)), 425 M()(torch.tensor(False), torch.tensor(4)), 426 ) 427 428 def test_convert_if_tuple_out(self): 429 class M(torch.nn.Module): 430 def true_fn(self, y, z): 431 return (z * z, z + z) 432 433 def false_fn(self, y, z): 434 return (y * y * y, y + y) 435 436 def forward(self, x: torch.Tensor, y: torch.Tensor): 437 z = y * y 438 439 if x: 440 res = self.true_fn(y, z) 441 else: 442 res = self.false_fn(y, z) 443 444 return res[0] + res[1] 445 446 inp = (torch.tensor(True), torch.tensor(4)) 447 ep_list = self._check_equal_ts_ep_converter(M(), inp) 448 449 for ep in ep_list[1:]: 450 torch.testing.assert_close( 451 ep.module()(torch.tensor(False), torch.tensor(4)), 452 M()(torch.tensor(False), torch.tensor(4)), 453 ) 454 455 def test_convert_if_multiple_out(self): 456 class M(torch.nn.Module): 457 def true_fn(self, y, z): 458 return z * z 459 460 def false_fn(self, y, z): 461 return y * y * y 462 463 def forward(self, x: torch.Tensor, y: torch.Tensor): 464 z = y * y 465 466 if x: 467 res1 = self.true_fn(y, z) 468 res2 = y 469 else: 470 res1 = z 471 res2 = self.false_fn(y, z) 472 473 return res1 + res2 474 475 inp = (torch.tensor(True), torch.tensor(4)) 476 ep_list = self._check_equal_ts_ep_converter(M(), inp) 477 478 for ep in ep_list[1:]: 479 torch.testing.assert_close( 480 ep.module()(torch.tensor(False), torch.tensor(4)), 481 M()(torch.tensor(False), torch.tensor(4)), 482 ) 483 484 def test_profiler__record_function(self): 485 class Module(torch.nn.Module): 486 def forward(self, x: torch.Tensor) -> torch.Tensor: 487 handle = torch.ops.profiler._record_function_enter_new("foo", None) 488 y = x * 2 + 4 489 torch.ops.profiler._record_function_exit(handle) 490 return y 491 492 x = torch.randn(10, 10) 493 self._check_equal_ts_ep_converter(Module(), (x,)) 494 495 def test_aten_floordiv(self): 496 class Module(torch.nn.Module): 497 def forward(self, x: torch.Tensor) -> torch.Tensor: 498 return x // 2 499 500 x = torch.randn(10, 10) 501 self._check_equal_ts_ep_converter(Module(), (x,)) 502 503 def test_aten___is__(self): 504 class Module(torch.nn.Module): 505 def forward( 506 self, x: torch.Tensor, y: torch.Tensor 507 ) -> Tuple[bool, torch.Tensor]: 508 z = x + 1 509 return x is y, z 510 511 # Traced function must return output that has tensors. 512 inp = (torch.randn(10, 10), torch.rand(10, 10)) 513 self._check_equal_ts_ep_converter(Module(), inp, ["script"]) 514 515 def test_aten___isnot__(self): 516 class Module(torch.nn.Module): 517 def forward( 518 self, x: torch.Tensor, y: torch.Tensor 519 ) -> Tuple[bool, torch.Tensor]: 520 z = x + 1 521 return x is not y, z 522 523 # Traced function must return output that has tensors. 524 inp = (torch.randn(10, 10), torch.rand(10, 10)) 525 self._check_equal_ts_ep_converter(Module(), inp, ["script"]) 526 527 def test_aten___not__(self): 528 class Module(torch.nn.Module): 529 def forward( 530 self, x: torch.Tensor, y: torch.Tensor 531 ) -> Tuple[bool, torch.Tensor]: 532 z = x + 1 533 return not (x is not y), z 534 535 # Traced function must return output that has tensors. 536 inp = (torch.randn(10, 10), torch.rand(10, 10)) 537 self._check_equal_ts_ep_converter(Module(), inp, ["script"]) 538 539 def test_ts2ep_converter_unpack(self): 540 class MUnpackList(torch.nn.Module): 541 def forward(self, x): 542 x, y = torch.split(x, 2) 543 return x + y 544 545 class MUnpackTuple(torch.nn.Module): 546 def forward(self, x_tuple: Tuple[torch.Tensor, torch.Tensor]): 547 x, y = x_tuple 548 x = x.cos() 549 return x + y 550 551 inp = (torch.ones(4),) 552 self._check_equal_ts_ep_converter(MUnpackList(), inp) 553 inp = ((torch.zeros(1, 4), torch.ones(1, 4)),) 554 self._check_equal_ts_ep_converter(MUnpackTuple(), inp) 555 556 @unittest.skipIf( 557 IS_WINDOWS, 558 "torch.cond doesn't go through torch.compile on windows" 559 "causing output not normalized as list", 560 ) 561 def test_convert_retrace_nested_scripted_modules(self): 562 class Wrapper(torch.nn.Module): 563 def __init__(self, mod) -> None: 564 super().__init__() 565 self.mod = mod 566 567 def forward(self, x, y): 568 return self.mod(x, y) 569 570 class LinearM(torch.nn.Module): 571 def __init__(self, dim: int) -> None: 572 super().__init__() 573 self.linear = torch.nn.Linear(dim, dim) 574 575 def forward(self, x, y): 576 return self.linear(y) 577 578 class M(torch.nn.Module): 579 def __init__(self, dim: int) -> None: 580 super().__init__() 581 m = LinearM(dim) 582 m = torch.jit.script(m) 583 self.mod1 = m 584 self.mod2 = Wrapper(m) 585 586 def forward(self, x: torch.Tensor, y: torch.Tensor): 587 if x: 588 return -self.mod1(x, y) - self.mod2(x, y) 589 else: 590 return -self.mod1(x, y) + self.mod2(x, y) 591 592 class NestedM(torch.nn.Module): 593 def __init__(self, dim: int) -> None: 594 super().__init__() 595 m = M(dim) 596 m = torch.jit.script(m) 597 self.mod1 = m 598 self.mod2 = Wrapper(m) 599 600 def forward(self, x: torch.Tensor, y: torch.Tensor): 601 if x: 602 return self.mod1(x, y) + self.mod2(x, y) 603 else: 604 return self.mod1(x, y) - self.mod2(x, y) 605 606 inp = ( 607 torch.tensor(True), 608 torch.randn([3, 3]), 609 ) 610 self._check_equal_ts_ep_converter(NestedM(3), inp) 611 612 def test_convert_nn_module_with_nested_param(self): 613 class M(torch.nn.Module): 614 def __init__(self, dim: int) -> None: 615 super().__init__() 616 self.linear = torch.nn.Linear(dim, dim) 617 618 def forward(self, x: torch.Tensor): 619 return self.linear(x) 620 621 class NestedM(torch.nn.Module): 622 def __init__(self, dim: int) -> None: 623 super().__init__() 624 self.linear = torch.nn.Linear(dim, dim) 625 self.m = M(dim) 626 627 def forward(self, x: torch.Tensor): 628 return self.linear(self.m(x)) 629 630 class SuperNestedM(torch.nn.Module): 631 def __init__(self, dim: int) -> None: 632 super().__init__() 633 self.linear = torch.nn.Linear(dim, dim) 634 self.m = NestedM(dim) 635 636 def forward(self, x: torch.Tensor): 637 return self.linear(self.m(x)) 638 639 inp = (torch.ones(3),) 640 orig_m = NestedM(3) 641 self._check_equal_ts_ep_converter(orig_m, inp) 642 orig_m = SuperNestedM(3) 643 self._check_equal_ts_ep_converter(orig_m, inp) 644 645 def test_convert_nn_module_with_nested_buffer(self): 646 class M(torch.nn.Module): 647 def __init__(self) -> None: 648 super().__init__() 649 self.w = torch.nn.Buffer(torch.randn(1)) 650 651 def forward(self, x: torch.Tensor): 652 return self.w + x 653 654 class NestedM(torch.nn.Module): 655 def __init__(self) -> None: 656 super().__init__() 657 self.m = M() 658 self.w = torch.nn.Buffer(torch.randn(1)) 659 660 def forward(self, x: torch.Tensor): 661 return self.w + self.m(x) 662 663 class SuperNestedM(torch.nn.Module): 664 def __init__(self) -> None: 665 super().__init__() 666 self.m = NestedM() 667 self.w = torch.nn.Buffer(torch.randn(1)) 668 669 def forward(self, x: torch.Tensor): 670 return self.w + self.m(x) 671 672 inp = (torch.ones(1),) 673 orig_m = NestedM() 674 self._check_equal_ts_ep_converter(orig_m, inp) 675 orig_m = SuperNestedM() 676 self._check_equal_ts_ep_converter(orig_m, inp) 677 678 def test_convert_nn_module_with_nested_if_and_buffer(self): 679 class M(torch.nn.Module): 680 def __init__(self) -> None: 681 super().__init__() 682 self.w = torch.nn.Buffer(torch.randn(1)) 683 self.count = 1 684 685 def forward(self, x: torch.Tensor): 686 return self.w + x + self.count 687 688 class NestedM(torch.nn.Module): 689 def __init__(self) -> None: 690 super().__init__() 691 self.m1 = M() 692 self.m2 = M() 693 self.w = torch.nn.Buffer(torch.randn(1)) 694 695 def forward(self, x: torch.Tensor): 696 if torch.sum(x) > 1: 697 return self.w + self.m1(x) 698 else: 699 return self.w + self.m2(x) 700 701 # Super nested, parameters neeed to lifted 702 # multiple times. 703 class SuperNestedM(torch.nn.Module): 704 def __init__(self) -> None: 705 super().__init__() 706 self.m1 = NestedM() 707 self.m2 = NestedM() 708 self.w = torch.nn.Buffer(torch.randn(1)) 709 710 def forward(self, x: torch.Tensor): 711 if torch.max(x) > 1: 712 return self.w + self.m1(x) 713 else: 714 return self.w + self.m2(x) 715 716 # Super nested module testing. 717 inp = (torch.ones(1),) 718 orig_m = SuperNestedM() 719 ep_list = self._check_equal_ts_ep_converter(orig_m, inp) 720 721 t = inp[0] 722 t -= 1 723 for ep in ep_list: 724 torch.testing.assert_close( 725 ep.module()(*inp), 726 orig_m(*inp), 727 ) 728 729 @unittest.skipIf( 730 IS_WINDOWS, 731 "torch.cond doesn't go through torch.compile on windows" 732 "causing output not normalized as list", 733 ) 734 def test_convert_nn_module_with_nested_if_and_param(self): 735 class M(torch.nn.Module): 736 def __init__(self, dim: int) -> None: 737 super().__init__() 738 self.linear = torch.nn.Linear(dim, dim) 739 740 def forward(self, x: torch.Tensor): 741 return self.linear(x) 742 743 class NestedM(torch.nn.Module): 744 def __init__(self, dim: int) -> None: 745 super().__init__() 746 self.m1 = M(dim) 747 self.m2 = M(dim) 748 self.linear = torch.nn.Linear(dim, dim) 749 750 def forward(self, x: torch.Tensor): 751 if torch.sum(x) > 1: 752 return self.linear(self.m1(x)) 753 else: 754 return self.linear(self.m2(x)) 755 756 # Super nested, parameters neeed to lifted 757 # multiple times. 758 class SuperNestedM1(torch.nn.Module): 759 def __init__(self, dim: int) -> None: 760 super().__init__() 761 self.m1 = NestedM(dim) 762 self.m2 = NestedM(dim) 763 self.linear = torch.nn.Linear(dim, dim) 764 765 def forward(self, x: torch.Tensor): 766 if torch.max(x) > 1: 767 return self.linear(self.m1(x)) 768 else: 769 return self.linear(self.m2(x)) 770 771 # Super nested, even the input needs to be 772 # lifted recursively due to value propogation optimiztaion. 773 class SuperNestedM2(torch.nn.Module): 774 def __init__(self, dim: int) -> None: 775 super().__init__() 776 self.m1 = NestedM(dim) 777 self.m2 = NestedM(dim) 778 self.linear = torch.nn.Linear(dim, dim) 779 780 def forward(self, x: torch.Tensor): 781 if torch.sum(x) > 1: 782 return self.linear(self.m1(x)) 783 else: 784 return self.linear(self.m2(x)) 785 786 # Basic module testing. 787 inp = (torch.ones(3),) 788 orig_m = M(3) 789 ep_list = self._check_equal_ts_ep_converter(orig_m, inp) 790 791 t = inp[0] 792 t -= 0.8 793 for ep in ep_list[1:]: 794 torch.testing.assert_close( 795 ep.module()(*inp), 796 orig_m(*inp), 797 ) 798 799 # Nested module testing. 800 inp = (torch.ones(3),) 801 orig_m = NestedM(3) 802 ep_list = self._check_equal_ts_ep_converter(orig_m, inp) 803 804 t = inp[0] 805 t -= 0.8 806 # Skip jit.traced because it specializes on one path. 807 for ep in ep_list[1:]: 808 torch.testing.assert_close( 809 ep.module()(*inp), 810 orig_m(*inp), 811 ) 812 813 # Super nested module testing. 814 inp = (torch.ones(3),) 815 orig_m = SuperNestedM1(3) 816 ep_list = self._check_equal_ts_ep_converter(orig_m, inp) 817 818 t = inp[0] 819 t -= 0.8 820 # Skip jit.traced because it specializes on one path. 821 for ep in ep_list[1:]: 822 torch.testing.assert_close( 823 ep.module()(*inp), 824 orig_m(*inp), 825 ) 826 827 # Super nested module testing. 828 inp = (torch.ones(3),) 829 orig_m = SuperNestedM2(3) 830 ep_list = self._check_equal_ts_ep_converter(orig_m, inp) 831 832 t = inp[0] 833 t -= 0.8 834 # Skip jit.traced because it specializes on one path. 835 for ep in ep_list[1:]: 836 torch.testing.assert_close( 837 ep.module()(*inp), 838 orig_m(*inp), 839 ) 840 841 def test_ts2ep_converter_contains(self): 842 class MIn(torch.nn.Module): 843 def forward(self, x: torch.Tensor): 844 return x.dtype in [torch.float32, torch.float64] 845 846 class MNotIn(torch.nn.Module): 847 def forward(self, x: torch.Tensor): 848 return x.dtype in [torch.int8] 849 850 class MTensorIn(torch.nn.Module): 851 def forward(self, x: torch.Tensor, x_dict: Dict[torch.Tensor, str]): 852 return x in x_dict 853 854 # Traced function must return output that has tensors. 855 inp = (torch.tensor(4),) 856 self._check_equal_ts_ep_converter(MIn(), inp, ["script"]) 857 self._check_equal_ts_ep_converter(MNotIn(), inp, ["script"]) 858 859 # TODO: update test to use reference for in. 860 inp = (torch.tensor(4), {torch.tensor(4): "foo"}) 861 self._check_equal_ts_ep_converter(MTensorIn(), inp, ["script"]) 862 inp = (torch.tensor(1), {torch.tensor(4): "foo"}) 863 self._check_equal_ts_ep_converter(MTensorIn(), inp, ["script"]) 864 865 def test_ts2ep_converter_custom_op(self): 866 with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 867 torch._dynamo.config.capture_scalar_outputs = True 868 torch._dynamo.config.capture_dynamic_output_shape_ops = True 869 870 torch.library.define( 871 "mylib::foo", 872 "(Tensor x) -> Tensor", 873 lib=lib, 874 ) 875 876 # PyTorch custorm op implementation 877 @torch.library.impl( 878 "mylib::foo", 879 "CompositeExplicitAutograd", 880 lib=lib, 881 ) 882 def foo_impl(x): 883 return x + x 884 885 # Meta function of the custom op. 886 @torch.library.impl_abstract( 887 "mylib::foo", 888 lib=lib, 889 ) 890 def foo_meta(x): 891 return x + x 892 893 class M(torch.nn.Module): 894 def forward(self, x): 895 return torch.ops.mylib.foo(x) 896 897 inp = (torch.randn(3, 3),) 898 m = M() 899 self._check_equal_ts_ep_converter(m, inp) 900 901 def test_convert_func_without_param(self): 902 def func1(x, y): 903 return x + y 904 905 def func2(x, y): 906 if x.sum() > 0: 907 return x + y 908 else: 909 return x - y 910 911 inp = ( 912 torch.tensor(1), 913 torch.tensor(1), 914 ) 915 self._check_equal_ts_ep_converter(func1, inp) 916 917 ep_list = self._check_equal_ts_ep_converter(func2, inp) 918 919 t = inp[0] 920 t -= 1 921 for ep in ep_list[1:]: 922 torch.testing.assert_close( 923 ep.module()(*inp), 924 func2(*inp), 925 ) 926 927 def test_implicit_constant_to_tensor_handling(self): 928 def func1(x): 929 return x + 2 930 931 def func2(x, y): 932 return x * y / (x - 2 * y) + y 933 934 def func3(x): 935 return x + torch.tensor([3]) 936 937 def func4(): 938 val = torch.tensor(float("inf")) 939 return torch.full((10, 10), val) 940 941 def func5(): 942 x = -1 943 return x * torch.ones(1, dtype=torch.float), torch.zeros( 944 1, dtype=torch.float 945 ) 946 947 def func6(x1, x2, x3, x4): 948 return ( 949 x1.numel(), 950 x1.size(), 951 x2.numel(), 952 x2.size(), 953 x3.numel(), 954 x3.size(), 955 x4.numel(), 956 x4.size(), 957 torch.ones(x1.numel()), # Just make sure downstream ops still work. 958 torch.ones(x1.size()), # Just make sure downstream ops still work. 959 ) 960 961 class M1(torch.nn.Module): 962 def __init__(self, value): 963 super().__init__() 964 self.x = torch.tensor(value) 965 966 def forward(self): 967 return self.x.clone() 968 969 class M2(torch.nn.Module): 970 def forward(self, x): 971 return torch.tensor(4) + x 972 973 inp = (torch.randn([2, 2]),) 974 self._check_equal_ts_ep_converter(func1, inp) 975 inp = (torch.randn([2, 2]), torch.randn([2, 2])) 976 self._check_equal_ts_ep_converter(func2, inp) 977 978 inp = (torch.randn([2, 2]),) 979 self._check_equal_ts_ep_converter(func3, inp) 980 981 self._check_equal_ts_ep_converter(func4, ()) 982 self._check_equal_ts_ep_converter(M1(5), ()) 983 984 inp = (torch.randn(2),) 985 self._check_equal_ts_ep_converter(M2(), inp) 986 987 self._check_equal_ts_ep_converter(func5, ()) 988 inp = ( 989 torch.randn([2, 3, 4]).to(torch.int8), 990 torch.randn([2, 3, 4]).to(torch.int32), 991 torch.randn([2, 3, 4]).to(torch.float32), 992 torch.randn([2, 3, 4]).to(torch.float64), 993 ) 994 ep_list = self._check_equal_ts_ep_converter(func6, inp) 995 996 # TODO: Additional check once dynamic shape is supported. 997 # for ep in ep_list: 998 # self.assertEqual( 999 # ep.module()( 1000 # torch.randn([1, 1, 1]).to(torch.int8), 1001 # torch.randn([1, 1, 1]).to(torch.int32), 1002 # torch.randn([1, 1, 1]).to(torch.float32), 1003 # torch.randn([1, 1, 1]).to(torch.float64), 1004 # )[0], 1 1005 # ) 1006 1007 def test_aten_tensor_dtype_int(self): 1008 class M(torch.nn.Module): 1009 def forward(self, x): 1010 y = torch.tensor(1, dtype=torch.int32) 1011 return y + x 1012 1013 ep_list = self._check_equal_ts_ep_converter(M(), (torch.tensor(1),)) 1014 for ep in ep_list: 1015 self.assertEqual(len(ep.constants), 1) 1016 1017 def test_aten_tensor_prim_dtype(self): 1018 class M(torch.nn.Module): 1019 def forward(self, x): 1020 y = torch.tensor(1, dtype=x.dtype) 1021 return y + x 1022 1023 ep_list = self._check_equal_ts_ep_converter(M(), (torch.tensor(1),)) 1024 for ep in ep_list: 1025 self.assertEqual(len(ep.constants), 1) 1026 1027 def test_aten_tensor_dynamic(self): 1028 class M(torch.nn.Module): 1029 def forward(self, x): 1030 s = x.shape[0] 1031 y = torch.tensor(s) 1032 return y 1033 1034 ep_list = self._check_equal_ts_ep_converter(M(), (torch.ones(3),)) 1035 for ep in ep_list: 1036 self.assertEqual(len(ep.constants), 0) 1037 1038 # TODO: Additional check once dynamic shape is supported. 1039 # for ep in ep_list: 1040 # torch.testing.assert_close( 1041 # ep.module()(torch.ones(4)), 1042 # M()(torch.ones(4)), 1043 # ) 1044 1045 class M(torch.nn.Module): 1046 def forward(self, x): 1047 s = x.shape[0] 1048 y = torch.tensor([s, s * 2, 1]) 1049 return y 1050 1051 ep_list = self._check_equal_ts_ep_converter(M(), (torch.ones(3),)) 1052 # Trace directly inline a tensor constant. 1053 for ep in ep_list[1:]: 1054 self.assertEqual(len(ep.constants), 0) 1055 1056 # TODO: Additional check once dynamic shape is supported. 1057 # for ep in ep_list: 1058 # torch.testing.assert_close( 1059 # ep.module()(torch.ones(4)), 1060 # M()(torch.ones(4)), 1061 # ) 1062 1063 def test_prim_tolist(self): 1064 class Module(torch.nn.Module): 1065 def forward(self, x: torch.Tensor) -> List[int]: 1066 return x.tolist() 1067 1068 inp = (torch.tensor([1, 2, 3]),) 1069 self._check_equal_ts_ep_converter(Module(), inp, ["script"]) 1070 1071 class Module(torch.nn.Module): 1072 def forward(self, x: torch.Tensor) -> List[List[int]]: 1073 return x.tolist() 1074 1075 inp = (torch.tensor([[1, 2, 3], [4, 5, 6]]),) 1076 self._check_equal_ts_ep_converter(Module(), inp, ["script"]) 1077 1078 def test_get_tensor_constants(self): 1079 # Since self.data is only read but not written, it is lifted as 1080 # constant tensors. 1081 class Foo(torch.nn.Module): 1082 def __init__(self) -> None: 1083 super().__init__() 1084 self.data = torch.randn(3, 2) 1085 1086 def forward(self, x: torch.Tensor) -> torch.Tensor: 1087 return x + self.data 1088 1089 class Goo(torch.nn.Module): 1090 def __init__(self) -> None: 1091 super().__init__() 1092 self.data = torch.randn(3, 2) 1093 self.foo = Foo() 1094 1095 def forward(self, x: torch.Tensor) -> torch.Tensor: 1096 return x + self.data + self.foo.data + self.foo(x) 1097 1098 inp = (torch.randn(3, 2),) 1099 goo = Goo() 1100 self._check_equal_ts_ep_converter(goo, inp) 1101 1102 def test_prim_SetAttr(self): 1103 class Module(torch.nn.Module): 1104 def __init__(self) -> None: 1105 super().__init__() 1106 self.data = torch.nn.Buffer(torch.ones(3, 2)) 1107 1108 def forward(self, x: torch.Tensor) -> torch.Tensor: 1109 self.data = self.data + x 1110 return x + x 1111 1112 inp = (torch.ones(3, 2),) 1113 self._check_equal_ts_ep_converter( 1114 Module, inp, ["script"], check_persistent=True 1115 ) 1116 1117 class Module(torch.nn.Module): 1118 def __init__(self) -> None: 1119 super().__init__() 1120 self.data = torch.nn.Buffer(torch.ones(3, 2)) 1121 1122 def forward(self, x: torch.Tensor) -> torch.Tensor: 1123 self.data = self.data + x 1124 return x + self.data 1125 1126 inp = (torch.ones(3, 2),) 1127 self._check_equal_ts_ep_converter( 1128 Module, inp, ["script"], check_persistent=True 1129 ) 1130 1131 # export lifts a tensor constant (self.data) as an input if it is not assigned. 1132 # If it is assigned, export will error and ask users to register it as a buffer. 1133 # In converter, we change tensor constants that are assigned as a buffer automatically, 1134 # since it might be hard to manually register them as buffers. 1135 class Module(torch.nn.Module): 1136 def __init__(self) -> None: 1137 super().__init__() 1138 self.data = torch.ones(3, 2) 1139 1140 def forward(self, x: torch.Tensor) -> torch.Tensor: 1141 self.data = self.data + x 1142 return x + self.data 1143 1144 inp = (torch.ones(3, 2),) 1145 self._check_equal_ts_ep_converter( 1146 Module, 1147 inp, 1148 ["script"], 1149 check_persistent=True, 1150 lifted_tensor_constants=OrderedDict([("data", torch.ones(3, 2))]), 1151 ) 1152 1153 class Module(torch.nn.Module): 1154 def __init__(self) -> None: 1155 super().__init__() 1156 self.count = 0 1157 1158 def forward(self, x: torch.Tensor) -> torch.Tensor: 1159 self.count += 1 1160 return x + self.count 1161 1162 # check_persistent is False since export specializes on non-tensor constants 1163 inp = (torch.ones(3, 2),) 1164 self._check_equal_ts_ep_converter( 1165 Module(), inp, ["script"], check_persistent=False 1166 ) 1167 1168 class M(torch.nn.Module): 1169 def __init__(self) -> None: 1170 super().__init__() 1171 self.count = 0 1172 1173 def forward(self, x): 1174 count1 = self.count 1175 self.count += 1 1176 count2 = self.count 1177 self.count += 1 1178 count3 = self.count 1179 return x + count1 + count2 + count3 1180 1181 inp = (torch.ones(1),) 1182 self._check_equal_ts_ep_converter(M(), inp, ["script"], check_persistent=False) 1183 1184 class M(torch.nn.Module): 1185 def __init__(self) -> None: 1186 super().__init__() 1187 self.w2 = torch.nn.Buffer(torch.ones(1)) 1188 1189 def forward(self, x: torch.Tensor): 1190 self.w2 += 1 1191 return self.w2 1192 1193 inp = (torch.ones(1),) 1194 self._check_equal_ts_ep_converter(M, inp, ["script"], check_persistent=True) 1195 1196 def test_raise_exception(self): 1197 class Module(torch.nn.Module): 1198 def forward(self, x: torch.Tensor, y: int) -> torch.Tensor: 1199 if y > 0: 1200 raise RuntimeError("test") 1201 return x + y 1202 1203 # match non-strict export behavior that errors when the given input leads to 1204 # RaiseException. 1205 with self.assertRaisesRegex(torch.jit.Error, "builtins.RuntimeError"): 1206 inp = (torch.randn(3, 2), 1) 1207 self._check_equal_ts_ep_converter(Module(), inp, ["script"]) 1208 1209 # Matching non-strict export behavior that only executes 1 if-branch according 1210 # to the given input. 1211 inp = (torch.randn(3, 2), 0) 1212 self._check_equal_ts_ep_converter(Module(), inp, ["script"]) 1213 1214 class Module(torch.nn.Module): 1215 def forward(self, x: torch.Tensor, y: int) -> torch.Tensor: 1216 z = x 1217 if y > 0: 1218 raise RuntimeError("test") 1219 # z = x 1220 else: 1221 z = x + y 1222 return x + y + z 1223 1224 # match non-strict export behavior that errors when the given input leads to 1225 # RaiseException. 1226 with self.assertRaisesRegex(torch.jit.Error, "builtins.RuntimeError"): 1227 inp = (torch.randn(3, 2), 1) 1228 self._check_equal_ts_ep_converter(Module(), inp, ["script"]) 1229 1230 # Matching non-strict export behavior that only executes 1 if-branch according 1231 # to the given input. 1232 inp = (torch.randn(3, 2), 0) 1233 self._check_equal_ts_ep_converter(Module(), inp, ["script"]) 1234 1235 def test_context_manager(self): 1236 class ContextManager: 1237 def __init__(self) -> None: 1238 self.count = 0 1239 return 1240 1241 def __enter__(self): 1242 self.count += 1 1243 return 1244 1245 def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 1246 self.count -= 1 1247 return 1248 1249 class M(torch.nn.Module): 1250 def forward(self, x, y): 1251 with ContextManager(): 1252 res = x + y 1253 return res 1254 1255 inp = (torch.ones(3, 3), torch.ones(3, 3)) 1256 self._check_equal_ts_ep_converter(M(), inp) 1257 1258 def test_hidden_input_name(self): 1259 @torch.jit.script 1260 def func1(x): 1261 return x + 1 1262 1263 def func2(*args): 1264 v = torch.cat(args, dim=1) 1265 return v * v 1266 1267 inp = (torch.randn([1, 1]),) 1268 self._check_equal_ts_ep_converter(func1, inp) 1269 1270 inp = (torch.ones(5, 5),) 1271 # Cannot script again. 1272 self._check_equal_ts_ep_converter(torch.ops.aten.relu, inp, ["trace"]) 1273 1274 M = 2 1275 Ns = [4, 2, 1] 1276 empty = torch.tensor([], dtype=torch.double) 1277 values = [empty] + [torch.randn(M, N) for N in Ns] 1278 # Cannot script variable length inputs. 1279 self._check_equal_ts_ep_converter(func2, tuple(values), ["trace"]) 1280 1281 def test_ts2ep_multi_outputs_on_call_ops(self): 1282 class M(torch.nn.Module): 1283 def __init__(self) -> None: 1284 super().__init__() 1285 self.pool = torch.nn.AdaptiveMaxPool2d((2, 2), return_indices=True) 1286 1287 def forward(self, x: torch.Tensor, y: torch.Tensor): 1288 return ( 1289 torch.max(x, dim=0), 1290 torch.topk(x, 3), 1291 torch.sort(x, dim=0), 1292 self.pool(y), 1293 ) 1294 1295 inp = (torch.randn([4, 4]), torch.randn([1, 1, 10, 10])) 1296 self._check_equal_ts_ep_converter(M(), inp) 1297 1298 def test_aten_append_t(self): 1299 class M(torch.nn.Module): 1300 def forward(self, x: List[torch.Tensor]): 1301 out = [] 1302 out.append(x[0] + x[1]) 1303 out.append(x[0] - x[1]) 1304 out1 = torch.cat(out) 1305 out.append(x[0] * x[1]) 1306 out2 = torch.cat(out) 1307 return out, out1, out2 1308 1309 inp = ([torch.ones(2, 3), torch.ones(2, 3)],) 1310 # Trace already unrolls the list. 1311 self._check_equal_ts_ep_converter(M(), inp, ["script"]) 1312 1313 def test_convert_script_object(self): 1314 class M1(torch.nn.Module): 1315 def __init__(self): 1316 super().__init__() 1317 self.tq = _empty_tensor_queue() 1318 1319 def forward(self, x: torch.Tensor): 1320 self.tq.push(x) 1321 torch.ops._TorchScriptTesting.queue_push(self.tq, x.cos()) 1322 return torch.ops._TorchScriptTesting.queue_pop(self.tq), self.tq.pop() 1323 1324 inp = (torch.randn(2, 3),) 1325 self._check_equal_ts_ep_converter(M1(), inp, ["script"]) 1326 1327 def test_ts2ep_with_loop(self): 1328 def func1(x, x_list: List[torch.Tensor]): 1329 a, b, c = x, x, x 1330 for i in range(1, 5, 2): 1331 for k in range(5): 1332 a = a + a + k 1333 b = b + b - k 1334 x_list.append(x_list[k] + x_list[k + 1]) 1335 for k in range(5): 1336 b = b + b - k 1337 c = c + c * k 1338 x_list.append(x_list[k] + x_list[k + 1] - x_list[k + 2]) 1339 return x, x_list 1340 1341 def func2(x): 1342 for i in range(x.size(0)): 1343 x = x * x * i 1344 return x 1345 1346 def func3(x): 1347 while x.sum() < 10: 1348 x += x.sin() 1349 return x 1350 1351 inp = ( 1352 torch.tensor(1), 1353 [torch.ones([2, 2]), torch.ones([2, 2]) * 2], 1354 ) 1355 # Trace unrolls the loop. 1356 self._check_equal_ts_ep_converter(func1, inp, ["script"]) 1357 1358 # TODO: (2/N) 1359 # Trace unrolls the loop. 1360 # self._check_equal_ts_ep_converter(func2, inp, ["script"]) 1361 1362 # TODO: (3/N) 1363 # Trace unrolls the loop. 1364 # self._check_equal_ts_ep_converter(func3, inp, ["script"]) 1365 1366 @unittest.skipIf( 1367 IS_WINDOWS, 1368 "Windows does not support qnnpack", 1369 ) 1370 def test_ts2ep_convert_quantized_model(self): 1371 class Standalone(torch.nn.Module): 1372 def __init__(self): 1373 super().__init__() 1374 self.quant = torch.ao.quantization.QuantStub() 1375 self.conv1 = torch.nn.Conv2d(1, 1, 1) 1376 self.conv2 = torch.nn.Conv2d(1, 1, 1) 1377 self.relu = torch.nn.ReLU() 1378 self.dequant = torch.ao.quantization.DeQuantStub() 1379 1380 def forward(self, x): 1381 x = self.quant(x) 1382 x = self.conv1(x) 1383 x = self.conv2(x) 1384 x = self.relu(x) 1385 x = self.dequant(x) 1386 return x 1387 1388 def fuse_model(self): 1389 torch.ao.quantization.fuse_modules( 1390 self, [["conv2", "relu"]], inplace=True 1391 ) 1392 1393 with override_quantized_engine("qnnpack"): 1394 model = Standalone() 1395 model.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack") 1396 model.fuse_model() 1397 torch.ao.quantization.prepare(model, inplace=True) 1398 model(torch.randn(4, 1, 4, 4)) 1399 torch.ao.quantization.convert(model, inplace=True) 1400 1401 # Use customized checking here, because state_dict of quantization will be 1402 # modified by the quantization pass. 1403 inp = (torch.randn(4, 1, 4, 4),) 1404 original_ts_model = torch.jit.script(model) 1405 ts_model = torch.jit.script(model) 1406 converter = TS2EPConverter(ts_model, inp) 1407 ep = converter.convert() 1408 1409 orig_out, _ = pytree.tree_flatten(original_ts_model(*inp)) 1410 ep_out, _ = pytree.tree_flatten(ep.module()(*inp)) 1411 self._check_tensor_list_equal(orig_out, ep_out) 1412 1413 def test_ts2ep_convert_quantized_model_with_opcontext(self): 1414 class M(torch.nn.Module): 1415 def __init__(self, linear_op): 1416 super().__init__() 1417 self.linear_op = linear_op 1418 1419 def forward(self, x): 1420 x = torch.ops.prepacked.linear_clamp_run(x, self.linear_op) 1421 return x 1422 1423 linear_op = torch.ops.prepacked.linear_clamp_prepack( 1424 torch.randn(10, 10), torch.randn(10) 1425 ) 1426 m = M(linear_op) 1427 inp = (torch.randn(1, 10),) 1428 self._check_equal_ts_ep_converter(m, inp, ["script"]) 1429 1430 1431if __name__ == "__main__": 1432 run_tests() 1433