1# Owner(s): ["oncall: jit"] 2 3import copy 4import io 5import os 6import sys 7import unittest 8 9import torch 10import torch.nn as nn 11import torch.nn.functional as F 12from torch.autograd import Function, Variable 13from torch.testing import FileCheck 14 15 16# Make the helper files in test/ importable 17pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 18sys.path.append(pytorch_test_dir) 19import warnings 20 21# Standard library 22from collections import namedtuple 23from itertools import chain 24from typing import Dict, List, Optional, Tuple 25 26from torch import Tensor 27from torch.testing._internal.common_cuda import with_tf32_off 28from torch.testing._internal.common_utils import ( 29 enable_profiling_mode_for_profiling_tests, 30 IS_SANDCASTLE, 31 skipIfCompiledWithoutNumpy, 32 skipIfCrossRef, 33 skipIfTorchDynamo, 34 suppress_warnings, 35 TemporaryFileName, 36) 37from torch.testing._internal.jit_utils import ( 38 _tmp_donotuse_dont_inline_everything, 39 _trace, 40 enable_cpu_fuser, 41 JitTestCase, 42 make_global, 43 RUN_CUDA, 44 RUN_CUDA_MULTI_GPU, 45) 46 47 48if __name__ == "__main__": 49 raise RuntimeError( 50 "This test file is not meant to be run directly, use:\n\n" 51 "\tpython test/test_jit.py TESTNAME\n\n" 52 "instead." 53 ) 54 55 56@skipIfTorchDynamo("Not a suitable test for TorchDynamo") 57class TestTracer(JitTestCase): 58 @unittest.skipIf(not RUN_CUDA, "requires CUDA") 59 def test_large_nbr_kernel_args(self): 60 class Recurrence(nn.Module): 61 def __init__(self, seq_len): 62 super().__init__() 63 self.seq_len = seq_len 64 65 def forward(self, input): 66 input = input.transpose(0, 1) 67 68 # Main loop 69 output = [] 70 for i in range(self.seq_len): 71 b = input[i] * 2 72 output.append(b) 73 74 output = torch.cat(output, 0).view(input.size(0), *output[0].size()) 75 output = output.transpose(0, 1) 76 return output 77 78 input_size = 8 79 batch_size = 2 80 seq_len = 130 81 82 rec = Recurrence(seq_len) 83 input = torch.rand(batch_size, seq_len, input_size) 84 85 torch.cuda.set_device(0) 86 rec = rec.cuda() 87 input = input.cuda() 88 89 traced_rec = torch.jit.trace(rec, (input)) 90 91 def test_trace_legacy_ctor(self): 92 class MyModule(nn.Module): 93 def forward(self, x): 94 return (x + 1, torch.FloatTensor([0])) 95 96 traced_rec = torch.jit.trace(MyModule(), torch.randn(2, 2)) 97 98 def test_simple(self): 99 x = torch.tensor([0.4], requires_grad=True) 100 y = torch.tensor([0.7], requires_grad=True) 101 102 def f(x, y): 103 return torch.sigmoid(torch.tanh(x * (x + y))) 104 105 self.checkTrace(f, (x, y)) 106 107 def test_trace_checking_with_global_name(self): 108 class MyClass(torch.nn.Module): 109 def forward(self, xs: List[Tensor]): 110 y = torch.cat(xs, dim=0) 111 return y 112 113 model = MyClass() 114 # Simulate these inputs being in the globals, like they would be if, 115 # e.g. they were defined outermost scope of a script 116 global input1, input2 117 input1 = torch.ones(2, 2) 118 input2 = torch.ones(2, 2) 119 m2 = torch.jit.trace(model, ((input1, input2),)) 120 121 def test_trace_aliased_parameter(self): 122 class M(nn.Module): 123 def __init__(self, x): 124 super().__init__() 125 self.x = nn.Parameter(x) 126 127 def forward(self, y): 128 return self.x + y 129 130 m = M(torch.rand(3, 4)) 131 r = torch.jit.trace(m, m.x) 132 t2 = torch.rand(3, 4) 133 self.assertEqual(r(t2), m.x + t2) 134 135 def test_trace_nested_fn(self): 136 class TracedInlineDecision(torch.nn.Module): 137 def forward(self, x, flag): 138 @torch.jit.script 139 def make_decision(flag, x): 140 if flag: 141 return x 142 else: 143 return torch.zeros_like(x) 144 145 x = torch.neg(x) 146 return make_decision(flag, x) 147 148 decision = TracedInlineDecision() 149 torch.jit.trace( 150 decision, 151 (torch.rand(3, 4), torch.tensor([True], dtype=torch.bool)), 152 check_trace=True, 153 ) 154 155 def test_trace_single_tuple(self): 156 x = torch.tensor(2.0) 157 158 def f2(x): 159 return (x,) 160 161 jit_f2 = torch.jit.trace(f2, x) 162 assert f2(x) == jit_f2(x) # fails 163 164 def test_trace_out_operator_with_two_output(self): 165 example_input = torch.rand(2, 8) 166 out_1, out_2 = torch.cummax(example_input, 1) 167 168 def run_cummax(example_input, out_1, out_2): 169 output_1, output_2 = torch.cummax(example_input, 1, out=(out_1, out_2)) 170 return output_1, output_2 171 172 trace_model = torch.jit.trace(run_cummax, (example_input, out_1, out_2)) 173 174 def test_trace_namedtuple(self): 175 Point = namedtuple("point", ["x", "y"]) 176 177 def f(p): 178 if type(p) is tuple: 179 p = Point(*p) 180 return p.x + p.y 181 182 p = Point(torch.randn(1), torch.randn(1)) 183 traced = torch.jit.trace(f, (p,)) 184 self.assertEqual(f(p), traced(p)) 185 186 def test_trace_topk(self): 187 class M(torch.nn.Module): 188 def forward(self, x, y): 189 return x.topk(y, dim=1)[1] 190 191 mod = M() 192 inputs = (torch.randint(0, 10, (20, 20)), torch.tensor(17)) 193 traced_func = torch.jit.trace(mod, inputs) 194 195 test_inputs = (torch.randint(0, 9, (9, 9)), torch.tensor(8)) 196 eager_out = mod(*test_inputs) 197 traced_out = traced_func(*test_inputs) 198 self.assertNotWarn( 199 lambda: traced_func(*test_inputs), 200 "Shouldn't throw slicing related warn here", 201 ) 202 self.assertEqual(eager_out, traced_out) 203 204 test_inputs = (torch.randint(0, 50, (50, 50)), torch.tensor(12)) 205 eager_out = mod(*test_inputs) 206 traced_out = traced_func(*test_inputs) 207 self.assertNotWarn( 208 lambda: traced_func(*test_inputs), 209 "Shouldn't throw slicing related warn here", 210 ) 211 self.assertEqual(eager_out, traced_out) 212 213 def test_typeas_trace_check(self): 214 a = torch.tensor([0.4], requires_grad=True) 215 b = torch.tensor([0.7], requires_grad=True) 216 217 def f(x, y): 218 return x.type_as(y) 219 220 trace = torch.jit.trace(f, (a, b)) 221 222 def test_trace_index(self): 223 x = torch.tensor([0.4], requires_grad=True) 224 y = torch.tensor([0], dtype=torch.int64) 225 226 def fn(x, y): 227 return x[y] 228 229 fn_traced = torch.jit.trace( 230 fn, 231 ( 232 x, 233 y, 234 ), 235 ) 236 237 self.assertEqual(fn(x, y), fn_traced(x, y)) 238 239 # Backwards tracing was broken for indexing by a constant, 240 # because it's internally implemented using as_strided, 241 # and we attempted to trace its derivative (which is not 242 # currently supported.) It currently works because 243 # slice() is now not marked as traceable. 244 def test_trace_index_constant(self): 245 x = torch.tensor([0.4], requires_grad=True) 246 247 def fn(x): 248 return x[0] 249 250 def run(f): 251 y = f(x) 252 grad = torch.autograd.grad(y, x)[0].clone() 253 return y, grad 254 255 traced_fn = torch.jit.trace(fn, torch.ones(1)) 256 self.assertEqual(run(fn), run(traced_fn)) 257 258 def test_index_put(self): 259 ten = torch.zeros(3, 3) 260 mask = torch.tensor( 261 [[True, True, True], [True, False, False], [True, True, False]] 262 ) 263 264 def test_fn(ten, mask): 265 ten[mask] = torch.ones(6) 266 return ten 267 268 traced_test_fn = torch.jit.trace(test_fn, (ten, mask)) 269 270 ten = torch.rand(3, 3) 271 self.assertEqual(test_fn(ten, mask), traced_test_fn(ten, mask)) 272 273 def test_canonicalize_tensor_iterator(self): 274 x = torch.randn(4, 4) 275 276 def f(x): 277 x = x + 2 278 x = x - 4 279 x = x * 6 280 x = x / 8 281 return x 282 283 traced = torch.jit.trace(f, (x,)) 284 f(x) 285 graph = traced.graph_for(x) 286 # There should be 4 int constants for the right sides of operators, plus one 287 # for the alpha argument for add and sub 288 self.assertTrue(str(traced.graph_for(x)).count(": int = prim::Constant") == 5) 289 290 @suppress_warnings 291 def test_constant(self): 292 x = torch.randn(2, 2, requires_grad=True) 293 294 def f(x): 295 return x.matmul(torch.diag(torch.tensor([2.0, 2.0]))) 296 297 self.checkTrace(f, (x,), (torch.ones(2, 2, requires_grad=True),)) 298 299 def test_wrapped_number(self): 300 # Scalar's get converted to 'wrapped' tensors of default tensor type. 301 # Wrapped tensors behave differently in certain promotion operations: 302 # float_tensor * double -> float but wrapped_float * double -> double. 303 # This can cause issues in check-trace if not handled correctly in 304 # `aten::isclose()`. 305 306 def foobar(): 307 x = -10000.0 308 result = x * torch.ones(1, dtype=torch.float) 309 return result 310 311 scripted = torch.jit.trace(foobar, (), check_trace=True) 312 313 def test_inplace_transplant(self): 314 x = torch.tensor([0.0], requires_grad=True) 315 316 def fn(x): 317 y = x.clone() 318 y.add_(2) 319 y.add_(3) 320 return y 321 322 g, _ = torch.jit._get_trace_graph(fn, (x,)) 323 self.run_pass("dce", g) 324 FileCheck().check_count("aten::clone", 1, exactly=True).check_count( 325 "aten::add_", 2, exactly=True 326 ).check_next("return").run(str(g)) 327 self.assertExportImport(g, (x,)) 328 329 def test_inplace_flags(self): 330 class InplaceFn(Function): 331 @staticmethod 332 def forward(ctx, x): 333 ctx.mark_dirty(x) 334 return x.add_(1) 335 336 @staticmethod 337 def backward(ctx, go): 338 return go 339 340 class RegularFn(Function): 341 @staticmethod 342 def forward(ctx, x): 343 return x.add(1) 344 345 @staticmethod 346 def backward(ctx, go): 347 return go 348 349 x = torch.tensor([0.0], requires_grad=True) 350 351 def fn(x): 352 y = RegularFn.apply(x) 353 y = InplaceFn.apply(y) 354 y = InplaceFn.apply(y) 355 y = RegularFn.apply(y) 356 return y 357 358 trace_graph, _ = torch.jit._get_trace_graph(fn, (x,), _force_outplace=True) 359 self.run_pass("dce", trace_graph) 360 ops = list(trace_graph.nodes()) 361 for op in ops: 362 self.assertTrue(op.hasAttribute("inplace")) 363 inplace_flags = [False, True, True, False] 364 for op, is_inplace in zip(ops, inplace_flags): 365 self.assertEqual(op.i("inplace"), is_inplace) 366 367 def test_inplace_check(self): 368 class MyInplaceFn(Function): 369 @staticmethod 370 def forward(self, x): 371 x.add_(1) 372 self.mark_dirty(x) 373 return x 374 375 @staticmethod 376 def backward(self, grad): 377 return grad 378 379 def fn(x): 380 return MyInplaceFn.apply(x) 381 382 x = torch.randn(5, 5) 383 ge = torch.jit.trace(fn, (x,), _force_outplace=True, check_trace=False) 384 with self.assertRaisesRegex(RuntimeError, "inplace MyInplaceFn"): 385 ge(x) 386 387 def test_force_outplace_check_fill(self): 388 def f(x): 389 return torch.empty(x.shape).fill_(7) 390 391 x = torch.randn(10, 15) 392 ft = torch.jit.trace(f, x, _force_outplace=True) 393 self.assertEqual(f(x), ft(x)) 394 395 def test_force_outplace_check_zero(self): 396 def f(x): 397 return torch.empty(x.shape).zero_() 398 399 x = torch.randn(10, 15) 400 ft = torch.jit.trace(f, x, _force_outplace=True) 401 self.assertEqual(f(x), ft(x)) 402 403 def do_trace_size(self, requires_grad): 404 def fn(x): 405 return x.view(x.shape[1] * 2, x.size(0), 2) 406 407 x = torch.randn(5, 2, 4, requires_grad=requires_grad) 408 y = torch.randn(4, 8, 4, requires_grad=requires_grad) 409 410 # Check that it behaves as expected 411 traced_fn = torch.jit.trace(fn, x) 412 self.assertEqual(traced_fn(y), fn(y)) 413 self.assertEqual(traced_fn(x), fn(x)) 414 415 def test_trace_size(self): 416 self.do_trace_size(False) 417 418 # test the different graph_executor path that happens when 419 # gradients are required and sizes are involved 420 def test_trace_size_with_grad(self): 421 self.do_trace_size(True) 422 423 def test_trace_numel(self): 424 def fn(x): 425 return x.numel() 426 427 x = torch.randn(2, 3, 4) 428 y = torch.randn(4, 5, 6) 429 430 traced_fn = torch.jit.trace(fn, x) 431 self.assertEqual(traced_fn(y), fn(y)) 432 self.assertEqual(traced_fn(x), fn(x)) 433 434 def do_trace_arange(self, requires_grad): 435 def arange(x): 436 return torch.arange(x.shape[0]) 437 438 def arange_scalar(x): 439 return torch.arange(12) 440 441 def arange_start_end(x): 442 return torch.arange(start=x.shape[0], end=x.shape[0] + 5) 443 444 x = torch.randn(5, 3, 2, requires_grad=requires_grad) 445 y = torch.randn(8, 2, 4, requires_grad=requires_grad) 446 447 # Check that it behaves as expected 448 traced_arange = torch.jit.trace(arange, x) 449 self.assertEqual(traced_arange(y), arange(y)) 450 self.assertEqual(traced_arange(x), arange(x)) 451 452 traced_arange_scalar = torch.jit.trace(arange_scalar, x) 453 self.assertEqual(traced_arange_scalar(y), arange_scalar(y)) 454 self.assertEqual(traced_arange_scalar(x), arange_scalar(x)) 455 456 traced_arange_start_end = torch.jit.trace(arange_start_end, x) 457 self.assertEqual(traced_arange_start_end(y), arange_start_end(y)) 458 self.assertEqual(traced_arange_start_end(x), arange_start_end(x)) 459 460 def test_trace_arange(self): 461 self.do_trace_arange(False) 462 463 # test the different graph_executor path that happens when 464 # gradients are required and sizes are involved 465 def test_trace_arange_with_grad(self): 466 self.do_trace_arange(True) 467 468 # Test that a trace of torch.full(x.shape) doesn't store the shape as a constant 469 def test_trace_full_dynamic_shape(self): 470 def full_with_shape_like(x): 471 return torch.full(x.shape, 2.0) 472 473 x = torch.randn(3, 4) 474 ge = torch.jit.trace(full_with_shape_like, example_inputs=x) 475 y = torch.randn(2, 7) 476 self.assertEqual(ge(y).shape, y.shape) 477 self.assertEqual(ge(x).shape, x.shape) 478 479 # Test that the trace of setitem doesn't store shapes as constants 480 # Fix https://github.com/pytorch/pytorch/issues/43548 481 def test_trace_slice_setitem_dynamic_shape(self): 482 def slice_setitem(x, y): 483 x[:, 2] = y + 1 484 return x 485 486 x = torch.randn(3, 4) 487 traced = torch.jit.trace(slice_setitem, (x, x[:, 0])) 488 x = torch.randn(10, 5) 489 self.assertEqual(traced(x.clone(), x[:, 0]), slice_setitem(x.clone(), x[:, 0])) 490 491 # Suppression: we are intentionally slicing a tensor, we don't care that it 492 # will be constantified 493 @suppress_warnings 494 def do_trace_slice(self, requires_grad): 495 def slice(x): 496 results = [] 497 for i in range(4): 498 results.append(x[: x.size(0) - i, i : x.size(2), i:3]) 499 return tuple(results) 500 501 def slice_select(x): 502 results = [] 503 for i in range(4): 504 results.append(x[:, i:, x.size(2) - 5]) 505 return tuple(results) 506 507 x = torch.randn(5, 6, 7, requires_grad=requires_grad) 508 y = torch.randn(7, 8, 9, requires_grad=requires_grad) 509 510 # Check that it behaves as expected 511 traced_slice = torch.jit.trace(slice, x) 512 self.assertEqual(traced_slice(y), slice(y)) 513 self.assertEqual(traced_slice(x), slice(x)) 514 515 traced_slice_select = torch.jit.trace(slice_select, x) 516 self.assertEqual(traced_slice_select(y), slice_select(y)) 517 self.assertEqual(traced_slice_select(x), slice_select(x)) 518 519 def test_trace_slice(self): 520 self.do_trace_slice(False) 521 522 # test the different graph_executor path that happens when 523 # gradients are required and sizes are involved 524 def test_trace_slice_with_grad(self): 525 self.do_trace_slice(True) 526 527 def test_trace_casts(self): 528 casts = [ 529 lambda x: x.byte(), 530 lambda x: x.float(), 531 lambda x: x.cpu(), 532 lambda x: x.to(device="cpu"), 533 lambda x: x.to(dtype=torch.int64), 534 lambda x: x.to(device="cpu", dtype=torch.float), 535 lambda x: x.to(x), 536 ] 537 538 def assertContainsCast(trace): 539 self.assertEqual( 540 sum(n.kind() == "aten::to" for n in trace.graph.nodes()), 1 541 ) 542 543 for cast in casts: 544 trace = torch.jit.trace(cast, torch.randn(2, 2)) 545 assertContainsCast(trace) 546 x = torch.randn(2, 2) 547 self.assertEqual(trace(x), cast(x)) 548 549 def to_tensor(x, y): 550 return x.to(y) 551 552 to_tensor_trace = torch.jit.trace( 553 to_tensor, (torch.randn(2, 2), torch.randn(1, 8)) 554 ) 555 assertContainsCast(to_tensor_trace) 556 x, y = torch.randn(2, 2), torch.randn(1, 10) 557 self.assertEqual(to_tensor_trace(x, y), to_tensor(x, y)) 558 559 @skipIfCompiledWithoutNumpy 560 @skipIfCrossRef 561 def test_trace_warn(self): 562 def fn(x): 563 int(x) # Warning 1. 564 y = x * 1 565 if y: # Warning 2. 566 pass 567 q = [x, x * 4] 568 z = q[y] 569 float(z) # Warning 3. 570 z.tolist() # Warning 4. 571 z.numpy() # Warning 5. 572 for _ in torch.ones(4, 4): # Warning 6. 573 pass 574 return z + 4 575 576 with warnings.catch_warnings(record=True) as warns: 577 traced_fn = torch.jit.trace(fn, torch.tensor([1])) 578 for warn in warns: 579 self.assertIs(warn.category, torch.jit.TracerWarning) 580 warns = [str(w.message) for w in warns] 581 self.assertIn("a Python integer", warns[0]) 582 self.assertIn("a Python boolean", warns[1]) 583 self.assertIn("a Python float", warns[2]) 584 self.assertIn("a Python list", warns[3]) 585 self.assertIn("a NumPy array", warns[4]) 586 self.assertIn("Iterating over", warns[5]) 587 588 def test_trace_tuple(self): 589 def fn(x, y): 590 return x, (x * y[1], x * y[0]) 591 592 x, y = torch.randn(2, 2), (torch.ones(2, 2), torch.randn(2, 2)) 593 traced_fn = torch.jit.trace(fn, (x, y)) 594 self.assertEqual(traced_fn(x, y), fn(x, y)) 595 # should be a tuple nested within another tuple 596 FileCheck().check_count("prim::TupleConstruct", 2, exactly=True).check_next( 597 "return" 598 ).run(str(traced_fn.graph)) 599 self.assertExportImport(traced_fn.graph, (x, y)) 600 601 def test_trace_random(self): 602 def f(mean, std): 603 return torch.normal(mean, std) 604 605 traced = torch.jit.trace( 606 f, (torch.zeros(2, 3), torch.ones(2, 3)), check_trace=False 607 ) 608 mean, std = torch.zeros(5, 5), torch.ones(5, 5) 609 with torch.random.fork_rng(devices=[]): 610 output = f(mean, std) 611 traced_output = traced(mean, std) 612 self.assertEqual(output, traced_output) 613 614 def test_trace_tensor_factory(self): 615 def run(**kwargs): 616 inputs_require_grads = kwargs.pop("inputs_require_grads", True) 617 618 def fn(x): 619 return x + torch.ones(2, 3, **kwargs) 620 621 input_kwargs = kwargs.copy() 622 if "out" in input_kwargs: 623 del input_kwargs["out"] 624 input = torch.ones(2, 3, **input_kwargs) 625 self.checkTrace(fn, (input,), inputs_require_grads=inputs_require_grads) 626 # check we recorded 'ones' and did not just record a constant 627 tfn = torch.jit.trace(fn, input) 628 self.assertTrue("ones" in str(tfn.graph)) 629 630 run() 631 run(dtype=torch.int, inputs_require_grads=False) 632 run(out=torch.tensor([])) 633 if RUN_CUDA: 634 run(device="cuda:0") 635 if RUN_CUDA_MULTI_GPU: 636 run(device="cuda:1") 637 638 def test_trace_indexed_assignment(self): 639 def stuff(x, y): 640 x = x.clone() 641 x[0] = y 642 return x 643 644 example = torch.rand(3, 4) 645 self.checkTrace(stuff, (example, example[0] + 1)) 646 647 # TODO: implement 648 @unittest.expectedFailure 649 def test_output_unflatten(self): 650 """Check that outputs of traced functions retain the original structure and nesting""" 651 652 def fn(x): 653 return ( 654 x * 2, 655 ( 656 x**2, 657 x + 4, 658 (x + 2,), 659 ), 660 x * 4, 661 ) 662 663 self.checkTrace(fn, (torch.randn(2, 2),)) 664 665 def test_input_flatten(self): 666 """Check that inputs to traced functions are flattened""" 667 668 def fn(x, t): 669 y, z = t 670 return x * y * z 671 672 inputs = (torch.randn(1), (torch.randn(1), torch.randn(1))) 673 self.checkTrace(fn, inputs) 674 675 def test_input_dict_empty(self): 676 def test(d): 677 pass 678 679 with self.assertRaises(RuntimeError): 680 self.checkTrace(test, {}) 681 682 def test_input_dict_remembers_keys(self): 683 """Check that the trace remembers which keys were in a dict input""" 684 685 class TestModule(torch.nn.Module): 686 def forward(self, dict_input): 687 return dict_input["x"] 688 689 input_1 = {"x": torch.tensor(1)} 690 m = TestModule() 691 m_traced = torch.jit.trace(m, (input_1,)) 692 self.assertEqual(m_traced(input_1), torch.tensor(1)) 693 694 # should work to change the values and not the keys 695 input_same_key_different_value = {"x": torch.tensor(2)} 696 self.assertEqual(m_traced(input_same_key_different_value), torch.tensor(2)) 697 698 # error to use something that doesn't have `x` 699 input_different_key = {"y": torch.tensor(3)} 700 with self.assertRaises(RuntimeError): 701 m_traced(input_different_key) 702 703 # it's okay to have additional elements in the dictionary, so long as 'x' is there 704 input_additional_key = {"x": torch.tensor(4), "y": torch.tensor(3)} 705 self.assertEqual(m_traced(input_additional_key), torch.tensor(4)) 706 707 def test_input_dict_insertion_order(self): 708 """Check that dictionary access doesn't care about insertion order""" 709 710 class TestModule(torch.nn.Module): 711 def forward(self, dict_input): 712 return dict_input["x"], dict_input["y"] 713 714 input_x_then_y = {} 715 input_x_then_y["x"] = torch.tensor(1) 716 input_x_then_y["y"] = torch.tensor(2) 717 718 m = TestModule() 719 m_traced = torch.jit.trace(m, (input_x_then_y,)) 720 721 self.assertEqual(m_traced(input_x_then_y), (torch.tensor(1), torch.tensor(2))) 722 723 input_y_then_x = {} 724 input_y_then_x["y"] = torch.tensor(4) 725 input_y_then_x["x"] = torch.tensor(3) 726 727 self.assertEqual(m_traced(input_y_then_x), (torch.tensor(3), torch.tensor(4))) 728 729 def test_input_dict_recursive(self): 730 class TestModule(torch.nn.Module): 731 def forward(self, dict_input): 732 return dict_input["x"][1] 733 734 input_1 = {"x": {1: torch.tensor(1)}} 735 m = TestModule() 736 m_traced = torch.jit.trace(m, (input_1,)) 737 738 input_2 = {"x": {1: torch.tensor(2)}} 739 self.assertEqual(m_traced(input_2), torch.tensor(2)) 740 741 def test_input_dict_checkTrace_mut(self): 742 def test(d): 743 d["x"].tanh_() 744 return d["x"] 745 746 inputs = {"x": torch.rand(3, 4), "y": torch.rand(3, 4)} 747 self.checkTrace(test, (inputs,), inputs_require_grads=False) 748 749 def test_input_dict_unify(self): 750 def test(d): 751 return d["int"], d["float"] 752 753 inputs = { 754 "int": torch.ones((2, 2), dtype=torch.int32), 755 "float": torch.ones((2, 2), dtype=torch.float32), 756 } 757 self.checkTrace(test, (inputs,), inputs_require_grads=False) 758 759 def test_input_tuple_of_dicts(self): 760 def test(t): 761 d = t[0] 762 return d["x"]["y"] 763 764 inputs = {"x": {"y": torch.rand(2, 3)}} 765 self.checkTrace(test, ((inputs, inputs),), allow_unused=True) 766 767 def test_input_dict_of_dicts(self): 768 def test(d): 769 return d["x"]["y"] 770 771 nested_input = {"y": torch.rand(2, 3)} 772 unified_nested = {"y": torch.rand(3, 2)} 773 inputs = {"x": nested_input, "force_unify": unified_nested} 774 self.checkTrace(test, (inputs,), allow_unused=True) 775 776 def test_input_dict_of_lists(self): 777 def test(d): 778 return d["x"][0] 779 780 inputs = {"x": [torch.rand(3, 2)]} 781 self.checkTrace(test, (inputs,)) 782 783 def test_input_list_toplevel_flatten(self): 784 def test(t1, t2): 785 return torch.add(t1, t2) 786 787 inputs = [torch.ones(2, 2), torch.rand(2, 2)] 788 self.checkTrace(test, inputs) 789 790 def test_input_list_toplevel_flatten_direct(self): 791 class Test(torch.nn.Module): 792 def forward(self, t1, t2): 793 return torch.add(t1, t2) 794 795 inputs = [torch.ones(2, 2), torch.rand(2, 2)] 796 torch.jit.trace(Test(), inputs) 797 798 def test_input_list_of_tuples(self): 799 def test(l): 800 return l[0][0] 801 802 inputs = [(torch.ones(2, 2),)] 803 self.checkTrace(test, (inputs,)) 804 805 def test_input_dict_empty_list(self): 806 def test(d): 807 pass 808 809 inputs = {1: []} 810 with self.assertRaisesRegex(RuntimeError, "List trace"): 811 self.checkTrace(test, (inputs,)) 812 813 def test_input_list_mixed_type(self): 814 def test(d): 815 pass 816 817 inputs = [torch.rand(2, 3), (torch.ones(2), torch.ones(2))] 818 with self.assertRaisesRegex(RuntimeError, "consistent"): 819 self.checkTrace(test, (inputs,)) 820 821 def test_conv(self): 822 x = torch.ones(20, 16, 50, 40) 823 g, outputs, inputs = torch.jit._get_trace_graph( 824 nn.Conv2d(16, 13, 3, bias=False), x, return_inputs=True 825 ) 826 m = self.createFunctionFromGraph(g) 827 self.assertEqual(outputs, m(*inputs)) 828 829 def test_max_pool(self): 830 x = torch.rand(20, 16, 10, 10) 831 832 def max_pool2d(x): 833 return F.max_pool2d(x, 2) + 2 834 835 trace = torch.jit.trace(max_pool2d, (x)) 836 graph = trace.graph_for(x) 837 FileCheck().check("aten::max_pool2d(").run(graph) 838 self.assertEqual(max_pool2d(x), trace(x)) 839 840 def test_nested_inplace(self): 841 x = torch.randn(2, 2) 842 g, outputs, inputs = torch.jit._get_trace_graph( 843 lambda x: F.threshold(x, 0, 0, inplace=True), (x,), return_inputs=True 844 ) 845 m = self.createFunctionFromGraph(g) 846 self.assertEqual(outputs, m(*inputs)) 847 FileCheck().check("threshold_").run(str(g)) 848 self.assertExportImport(g, (x,)) 849 850 def test_repeated_input(self): 851 def fn(a, b): 852 return a + b 853 854 ge = self.checkTrace(fn, [torch.randn(2, 2)] * 2) 855 inputs = set(ge.graph.inputs()) 856 # three instead of 2 because the export/import in checkTrace adds a 857 # `self` module argument 858 self.assertTrue(len(inputs) == 3) 859 860 def test_repeated_output(self): 861 def fn(a, b): 862 z = a + b 863 return z, z 864 865 ge = self.checkTrace(fn, [torch.randn(2, 2) for _ in range(2)]) 866 tuple_output = list(ge.graph.outputs())[0] 867 tuple_inputs = list(tuple_output.node().inputs()) 868 self.assertTrue(tuple_inputs[0] == tuple_inputs[1]) 869 870 def test_inplace_copy(self): 871 x = torch.randn(4, 4, requires_grad=True) 872 873 def f(x): 874 out = torch.zeros(x.size()) 875 out.copy_(x) 876 return out 877 878 g, outputs, inputs = torch.jit._get_trace_graph(f, (x,), return_inputs=True) 879 self.run_pass("dce", g) 880 m = self.createFunctionFromGraph(g) 881 self.assertEqual(outputs, m(*inputs)) 882 self.assertExportImport(g, (x,)) 883 884 def test_inplace_copy_force_outplace(self): 885 x = torch.randn(4, 4, requires_grad=True) 886 887 def f(x): 888 out = torch.zeros(x.size()) 889 out.copy_(x) 890 return out 891 892 g, outputs, inputs = torch.jit._get_trace_graph( 893 f, (x,), return_inputs=True, _force_outplace=True 894 ) 895 self.run_pass("dce", g) 896 m = self.createFunctionFromGraph(g) 897 self.assertEqual(outputs, m(*inputs)) 898 self.assertExportImport(g, (x,)) 899 FileCheck().check("expand_as").run(str(g)) 900 901 def test_shared_param(self): 902 class MyModule(torch.nn.Module): 903 def __init__(self) -> None: 904 super().__init__() 905 self.b = self.a = nn.Parameter(torch.randn(2, 2)) 906 907 def forward(self, x): 908 return x * self.a + self.b 909 910 m = MyModule() 911 g, _ = torch.jit._get_trace_graph(m, (torch.randn(2, 2),)) 912 self.run_pass("dce", g) 913 self.assertEqual(len(list(g.inputs())), 2) 914 FileCheck().check("mul").check("add").run(str(g)) 915 916 def run_ge_tests(self, optimize, use_cuda): 917 with enable_profiling_mode_for_profiling_tests(): 918 with torch.jit.optimized_execution(optimize): 919 920 def rand(*args): 921 t = torch.rand(*args).float() 922 if use_cuda: 923 t = t.cuda() 924 return t 925 926 self.checkTrace( 927 lambda a, b: a * b + b, [rand(1), rand(1)], [rand(2, 3), rand(2, 3)] 928 ) 929 # trivial identity 930 self.checkTrace(lambda a, b: (b, a), [rand(1), rand(1)]) 931 932 def foo(a): 933 t = a * a 934 return t * t, 4 * t 935 936 self.checkTrace(foo, [rand(1)]) 937 # unused input 938 self.checkTrace( 939 lambda a, b: a * a, [rand(1), rand(1)], allow_unused=True 940 ) 941 # test outputs that do not get used in grad 942 self.checkTrace(foo, [rand(1)], drop=1) 943 # test autograd fallback 944 self.checkTrace( 945 lambda a, b: a * b / (a - 2 * b) + b, [rand(1), rand(1)] 946 ) 947 948 def test_ge_unoptimized(self): 949 self.run_ge_tests(False, False) 950 951 @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle") 952 @enable_cpu_fuser 953 def test_ge_optimized(self): 954 with enable_profiling_mode_for_profiling_tests(): 955 self.run_ge_tests(True, False) 956 957 @unittest.skipIf(not RUN_CUDA, "requires CUDA") 958 def test_ge_cuda(self): 959 self.run_ge_tests(True, True) 960 961 # more manual test of graph executor that can be used as a scratchpad 962 def test_ge(self): 963 def foo(a, b): 964 return a * b / (a - b) + b 965 966 V = Variable 967 a, b = V(torch.rand(1)), V(torch.rand(1)) 968 ge = torch.jit.trace(foo, (a, b)) 969 a, b = V(torch.rand(1), requires_grad=True), V( 970 torch.rand(1), requires_grad=True 971 ) 972 (r,) = ge(a, b) 973 da, db = torch.autograd.grad(r + 3, [a, b], create_graph=True) 974 975 l2 = da * db + db * db 976 g2result = torch.autograd.grad(l2, [da, db]) 977 978 r = foo(a, b) 979 da2, db2 = torch.autograd.grad(r + 3, [a, b], create_graph=True) 980 self.assertEqual(da, da2) 981 self.assertEqual(db, db2) 982 l3 = da2 * db2 + db2 * db2 983 g2result2 = torch.autograd.grad(l3, [da2, db2]) 984 self.assertEqual(g2result, g2result2) 985 986 def test_trace_annotation(self): 987 @_trace(torch.rand(1)) 988 def foo(a): 989 return a + a + a 990 991 x = torch.randn(5, 5) 992 self.assertEqual(foo(x), x + x + x) 993 994 @unittest.skipIf(not RUN_CUDA, "calls .cuda()") 995 # By default, on Ampere or later GPUs, nn.Linear computes float tensors at TF32 precision. 996 # We want float tensors to be computed at full precision in order to use the default precision 997 @with_tf32_off 998 def test_traced_module_cuda(self): 999 class Model(nn.Module): 1000 def __init__(self, num_features, num_layers): 1001 super().__init__() 1002 self.num_layers = num_layers 1003 layers = [ 1004 [nn.Linear(num_features, num_features), nn.Sigmoid()] 1005 for _ in range(num_layers) 1006 ] 1007 self.submodule = nn.Sequential(*chain(*layers)) 1008 1009 def forward(self, x): 1010 for i in range(self.num_layers): 1011 x = self.submodule[i](x) + x 1012 return x 1013 1014 model = Model(5, 3) 1015 x = torch.randn(2, 5) 1016 traced_model = torch.jit.trace(model, x) 1017 1018 # We're missing some attributes these modules had initially. Make sure we can 1019 # still get the __repr__() 1020 model.__repr__() 1021 1022 # XXX: indexing sequentials is broken 1023 linear_submodule = next(iter(traced_model.submodule._modules.values())) 1024 1025 # All attributes that aren't parameters should raise 1026 with self.assertRaises(AttributeError): 1027 linear_submodule.in_features 1028 linear_submodule.weight 1029 linear_submodule.weight = nn.Parameter( 1030 torch.randn(linear_submodule.weight.shape) 1031 ) 1032 with self.assertRaises(RuntimeError): 1033 del linear_submodule.weight 1034 1035 # Submodules can't be called 1036 with self.assertRaises(RuntimeError): 1037 linear_submodule(x) 1038 1039 # Type casts 1040 linear_submodule.cuda() 1041 traced_model.float().cuda() 1042 cuda_out = traced_model(x.float().cuda()) 1043 traced_model.cpu() 1044 cpu_out = traced_model(x.float()) 1045 self.assertEqual(cpu_out, cuda_out) 1046 traced_model.to("cuda") 1047 cuda_out = traced_model(x.float().cuda()) 1048 traced_model.to("cpu") 1049 cpu_out = traced_model(x.float()) 1050 self.assertEqual(cpu_out, cuda_out) 1051 traced_model.to(torch.get_default_dtype()) 1052 1053 # state_dict + load_state_dict 1054 state = {k: v.clone() for k, v in traced_model.state_dict().items()} 1055 new_state = {k: v.clone().fill_(1) for k, v in state.items()} 1056 out = traced_model(x) 1057 traced_model.load_state_dict(new_state) 1058 out_ones = traced_model(x) 1059 traced_model.load_state_dict(state) 1060 out_state = traced_model(x) 1061 self.assertEqual(out, out_state) 1062 self.assertNotEqual(out, out_ones) 1063 1064 @unittest.skipIf(not RUN_CUDA, "uses cuda") 1065 def test_type_same_device(self): 1066 class Model(torch.nn.Module): 1067 def __init__(self) -> None: 1068 super().__init__() 1069 self.dtype = torch.float16 1070 1071 def forward(self, x=None): 1072 h = x.type(self.dtype) 1073 return h 1074 1075 a = Model() 1076 b = torch.jit.trace( 1077 a, example_inputs=(torch.ones([1], device=torch.device("cuda")),) 1078 ) 1079 FileCheck().check_not("device").run(b.code) 1080 1081 def test_export_no_reorder(self): 1082 def func(a, b): 1083 return a * b / (a - 2 * b) + b 1084 1085 recording_inputs = [ 1086 torch.tensor( 1087 [0.55619788169860839844], dtype=torch.float32, requires_grad=True 1088 ), 1089 torch.tensor( 1090 [0.25947844982147216797], dtype=torch.float32, requires_grad=True 1091 ), 1092 ] 1093 1094 ge1 = torch.jit.trace(func, recording_inputs) 1095 ge2 = self.getExportImportCopy(ge1) 1096 1097 outputs_ge1 = ge1(*recording_inputs) 1098 outputs_ge2 = ge2(*recording_inputs) 1099 1100 grad_ge1 = torch.autograd.grad(outputs_ge1, recording_inputs) 1101 grad_ge2 = torch.autograd.grad(outputs_ge2, recording_inputs) 1102 self.assertTrue(outputs_ge1 == outputs_ge2) 1103 self.assertTrue(grad_ge1 == grad_ge2) 1104 1105 def test_python_function(self): 1106 class MyFn(Function): 1107 @staticmethod 1108 def forward(ctx, x): 1109 return x + 1 1110 1111 @staticmethod 1112 def backward(ctx, grad_output): 1113 return grad_output 1114 1115 @_trace(torch.zeros(2)) 1116 def fn(x): 1117 return MyFn.apply(x + 2) + 3 1118 1119 x = torch.tensor([1.0, 2.0, 3.0]) 1120 y = torch.randn(2, 2, requires_grad=True) 1121 fn(x) 1122 fn(y) 1123 1124 def test_python_function_tup(self): 1125 class MyFn(Function): 1126 @staticmethod 1127 def forward(ctx, x): 1128 return x + 1, x - 1 1129 1130 @staticmethod 1131 def backward(ctx, grad_output): 1132 return grad_output, grad_output 1133 1134 @_trace(torch.zeros(2)) 1135 def fn(x): 1136 a, b = MyFn.apply(x + 2) 1137 return a + b + 3 1138 1139 x = torch.tensor([1.0, 2.0, 3.0]) 1140 y = torch.randn(2, 2, requires_grad=True) 1141 fn(x) 1142 fn(y) 1143 1144 def test_trace_detach(self): 1145 def foo(x, w): 1146 return torch.matmul(x, w).detach() 1147 1148 traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5))) 1149 1150 FileCheck().check("matmul").check("detach").run(str(traced.graph)) 1151 x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True) 1152 traced_result = traced(x, w) 1153 self.assertEqual(foo(x, w), traced_result) 1154 self.assertFalse(traced_result.requires_grad) 1155 self.assertIsNone(traced_result.grad_fn) 1156 1157 def test_trace_detach_redispatch(self): 1158 def foo(x, w): 1159 y = torch.matmul(x, w) 1160 assert y.requires_grad 1161 y = y.detach() 1162 # Make sure trace kernel redispatches to the right lower kernel. 1163 assert not y.requires_grad 1164 return y 1165 1166 x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True) 1167 # With `check_trace=True` it will run with `@torch.no_grad()` and break assert. 1168 torch.jit.trace(foo, (x, w), check_trace=False) 1169 1170 def test_trace_detach_inplace(self): 1171 def foo(x, w): 1172 y = torch.matmul(x, w) 1173 y.detach_() 1174 return y 1175 1176 traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5))) 1177 1178 FileCheck().check("matmul").check("detach(").run(str(traced.graph)) 1179 x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True) 1180 traced_result = traced(x, w) 1181 self.assertEqual(foo(x, w), traced_result) 1182 self.assertFalse(traced_result.requires_grad) 1183 self.assertIsNone(traced_result.grad_fn) 1184 1185 def test_trace_detach_inplace_redispatch(self): 1186 def foo(x, w): 1187 y = torch.matmul(x, w) 1188 assert y.requires_grad 1189 y.detach_() 1190 # Make sure trace kernel redispatches to the right lower kernel. 1191 assert not y.requires_grad 1192 return y 1193 1194 x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True) 1195 # With `check_trace=True` it will run with `@torch.no_grad()` and break assert. 1196 torch.jit.trace(foo, (x, w), check_trace=False) 1197 1198 def test_trace_slice_full_dim(self): 1199 def foo(x): 1200 return x[0:5, 0] + 1.0 1201 1202 traced = torch.jit.trace(foo, (torch.rand(5, 4),)) 1203 test_x = torch.rand(6, 3) 1204 self.assertEqual(foo(test_x), traced(test_x)) 1205 1206 def test_trace_dict_input(self): 1207 class Bar(torch.nn.Module): 1208 def __init__(self) -> None: 1209 super().__init__() 1210 self.foo = Foo() 1211 1212 def forward(self, a, b): 1213 return self.foo({"a": a, "b": b})["a"] 1214 1215 class Foo(torch.nn.Module): 1216 def forward(self, x): 1217 return {"a": x["a"] * x["b"]} 1218 1219 x = (torch.rand(3), torch.rand(3)) 1220 model = Bar() 1221 self.checkTrace(model, x) 1222 1223 def test_trace_dict_output(self): 1224 class TraceDictStrTensor(torch.nn.Module): 1225 def forward(self, a, b): 1226 return {"a": a, "b": b} 1227 1228 class TraceDictTensorTensor(torch.nn.Module): 1229 def forward(self, a, b): 1230 return {a: b, b: a} 1231 1232 x = (torch.rand(3), torch.rand(3)) 1233 with self.assertRaisesRegex(RuntimeError, r"Encountering a dict at the output"): 1234 torch.jit.trace(TraceDictStrTensor(), x) 1235 1236 traced_dict_str_mod = torch.jit.trace(TraceDictStrTensor(), x, strict=False) 1237 self.assertEqual(traced_dict_str_mod(*x), {"a": x[0], "b": x[1]}) 1238 1239 traced_dict_tensor_mod = torch.jit.trace( 1240 TraceDictTensorTensor(), x, strict=False 1241 ) 1242 self.assertEqual(traced_dict_tensor_mod(*x), {x[0]: x[1], x[1]: x[0]}) 1243 1244 def test_trace_with_tensor_list_output(self): 1245 def f(): 1246 return [torch.zeros(1), torch.zeros(5)] 1247 1248 with self.assertWarnsRegex( 1249 torch.jit.TracerWarning, "cause the trace to be incorrect" 1250 ): 1251 torch.jit.trace(f, []) 1252 traced_non_strict_f = torch.jit.trace(f, [], strict=False) 1253 self.assertEqual(traced_non_strict_f(), f()) 1254 1255 def test_trace_with_number_list_output(self): 1256 def f(): 1257 return [1, 5] 1258 1259 with self.assertRaisesRegex( 1260 RuntimeError, r"Only tensors.+can be output from traced functions" 1261 ): 1262 traced_f = torch.jit.trace(f, []) 1263 1264 def test_trace_with_nested_tensor_list_output(self): 1265 def f(): 1266 return [[torch.zeros(1)], [torch.zeros(5)]] 1267 1268 with self.assertRaisesRegex( 1269 RuntimeError, r"Only tensors.+can be output from traced functions" 1270 ): 1271 traced_f = torch.jit.trace(f, []) 1272 1273 def test_trace_with_nested_strided_tensor_output(self): 1274 @torch.jit.script 1275 def nt_construct(values, kv_lengths): 1276 kv_lengths_list: List[int] = kv_lengths.tolist() 1277 return torch._nested_tensor_from_tensor_list( 1278 list(values.split(kv_lengths_list, dim=0)), None, None, None, None 1279 ) 1280 1281 def f(x, offsets): 1282 kv_lengths = offsets[1:] - offsets[:-1] 1283 return nt_construct(x, kv_lengths).cos() 1284 1285 x = torch.rand(5, 4) 1286 offsets = torch.tensor([0, 2, 5]) 1287 ref = f(x, offsets) 1288 f_t = torch.jit.trace(f, (x, offsets)) 1289 res = f_t(x, offsets) 1290 self.assertEqual(ref, res) 1291 x2 = torch.rand((8, 4)) 1292 offsets2 = torch.tensor([0, 2, 4, 8]) 1293 self.assertEqual(f(x2, offsets2), f_t(x2, offsets2)) 1294 1295 def test_trace_variable_instantiation(self): 1296 def random_foo(x): 1297 return Variable(Variable(x) + 1.0) 1298 1299 random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),)) 1300 1301 x = torch.rand(5, 6) 1302 self.assertEqual(random_foo(x), random_foo_traced(x)) 1303 1304 def test_trace_slice_expr_complete_type(self): 1305 def random_foo(x): 1306 return x + 1.0 1307 1308 random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),)) 1309 1310 @torch.jit.script 1311 def random_bar(x): 1312 return random_foo_traced(x)[0:1] 1313 1314 x = torch.rand(3, 4) 1315 self.assertEqual(random_bar(x), (x + 1)[0:1]) 1316 1317 def test_trace_inline_shape(self): 1318 # testing peephole optimization of size is turned into a constant 1319 # in script fn 1320 1321 @torch.jit.script 1322 def tensor_size(x: torch.Tensor) -> torch.Tensor: 1323 return torch.tensor([x.size()[0]]) 1324 1325 self.assertEqual( 1326 tensor_size( 1327 torch.rand( 1328 15, 1329 ) 1330 ), 1331 torch.tensor([15]), 1332 ) 1333 1334 traced_tensor_size = torch.jit.trace( 1335 tensor_size, 1336 torch.rand( 1337 7, 1338 ), 1339 ) 1340 1341 self.assertEqual( 1342 traced_tensor_size( 1343 torch.rand( 1344 15, 1345 ) 1346 ), 1347 torch.tensor([15]), 1348 ) 1349 1350 @torch.jit.script 1351 def use_device(x): 1352 return torch.zeros_like(x, device=x.device) 1353 1354 def foo(x): 1355 return use_device(x) 1356 1357 traced_tensor_size = torch.jit.trace( 1358 foo, 1359 torch.rand( 1360 7, 1361 ), 1362 ) 1363 self.run_pass("inline", traced_tensor_size.graph) 1364 FileCheck().check("prim::device").run(traced_tensor_size.graph) 1365 1366 def test_trace_save(self): 1367 def fn(x): 1368 return x + 2 1369 1370 def check(func): 1371 with TemporaryFileName() as fname: 1372 func.save(fname) 1373 loaded = torch.jit.load(fname) 1374 input = torch.randn(2, 2) 1375 self.assertEqual(func(input), loaded(input)) 1376 1377 out = torch.jit.trace(fn, (torch.ones(2, 2),)) 1378 check(out) 1379 1380 def test_trace_optioanl_dtype(self): 1381 class Test(torch.nn.Module): 1382 def forward(self): 1383 return torch.arange(5) 1384 1385 traced = torch.jit.trace(Test(), ()) 1386 torch.allclose(traced(), Test()()) 1387 1388 def test_trace_save_load_copy(self): 1389 class Test(torch.nn.Module): 1390 def __init__(self) -> None: 1391 super().__init__() 1392 self.conv = torch.nn.Conv2d(3, 3, 3) 1393 1394 def forward(self, x): 1395 return self.conv(x) 1396 1397 traced = torch.jit.trace(Test(), torch.rand(1, 3, 224, 224)) 1398 buffer = io.BytesIO() 1399 torch.jit.save(traced, buffer) 1400 buffer.seek(0) 1401 loaded = torch.jit.load(buffer) 1402 # should work 1403 copy.copy(loaded) 1404 copy.deepcopy(loaded) 1405 1406 def test_trace_export_fns(self): 1407 class Foo(torch.nn.Module): 1408 def __init__(self) -> None: 1409 super().__init__() 1410 self.a = 3 1411 1412 @torch.jit.export 1413 def __getstate__(self): 1414 return (3, self.training) 1415 1416 @torch.jit.export 1417 def __setstate__(self, state): 1418 self.a = state[0] 1419 self.training = state[1] 1420 1421 def forward(self, x): 1422 return x + self.a 1423 1424 f = Foo() 1425 1426 traced = torch.jit.trace(f, (torch.rand(3, 4),)) 1427 expected_names = ["__getstate__", "__setstate__"] 1428 1429 def check(mod): 1430 self.assertTrue( 1431 all(name in mod._c._method_names() for name in expected_names) 1432 ) 1433 1434 check(traced) 1435 1436 imported = self.getExportImportCopy(traced) 1437 check(imported) 1438 1439 def test_trace_export_fns_recursive(self): 1440 class Foo(torch.nn.Module): 1441 def __init__(self) -> None: 1442 super().__init__() 1443 self.a = 3 1444 1445 @torch.jit.export 1446 def __getstate__(self): 1447 return (3, self.training) 1448 1449 @torch.jit.export 1450 def __setstate__(self, state): 1451 self.a = state[0] 1452 self.training = state[1] 1453 1454 def forward(self, x): 1455 return x + self.a 1456 1457 class Wrapper(torch.nn.Module): 1458 def __init__(self) -> None: 1459 super().__init__() 1460 self.foo = Foo() 1461 1462 def forward(self, x): 1463 return self.foo(x) 1464 1465 f = Wrapper() 1466 1467 traced = torch.jit.trace(f, (torch.rand(3, 4),)) 1468 expected_names = ["__getstate__", "__setstate__"] 1469 1470 def check(mod): 1471 self.assertTrue( 1472 all(name in mod._c._method_names() for name in expected_names) 1473 ) 1474 1475 check(traced.foo) 1476 1477 imported = self.getExportImportCopy(traced) 1478 check(imported.foo) 1479 1480 # Note that Bar's forward can only be traced, but not scripted 1481 class Bar(nn.Module): 1482 @torch.jit.export 1483 def addTwo(self, x): 1484 return x + 2 1485 1486 def forward(self, input): 1487 return (lambda a: a + 1)(input) # noqa: PLC3002 1488 1489 # When tracing Bar as a submodule, we only want to script the 1490 # exported methods, and we want to keep the forwards still 1491 # being traced. 1492 class WrapperExports(torch.nn.Module): 1493 def __init__(self) -> None: 1494 super().__init__() 1495 self.bar = Bar() 1496 1497 @torch.jit.export 1498 def addOne(self, x): 1499 return x + 1 1500 1501 def forward(self, x): 1502 return self.bar(x) 1503 1504 f = WrapperExports() 1505 1506 traced = torch.jit.trace(f, (torch.rand(3, 4),)) 1507 expected_names = ["addOne"] 1508 check(traced) 1509 1510 def test_trace_autograd_function(self): 1511 class TestFunc(torch.autograd.Function): 1512 @staticmethod 1513 def forward(ctx, input): 1514 return torch.neg(input) 1515 1516 @staticmethod 1517 def backward(ctx, grad_output): 1518 return torch.neg(grad_output) 1519 1520 class TracedModule(torch.nn.Module): 1521 def forward(self, x): 1522 return torch.relu(TestFunc.apply(x)) 1523 1524 class Wrapper(torch.nn.Module): 1525 def __init__(self) -> None: 1526 super().__init__() 1527 self.tm = TracedModule() 1528 1529 def forward(self, x): 1530 return self.tm(x) 1531 1532 traced = torch.jit.trace(Wrapper(), (torch.rand(3, 4),)) 1533 1534 def test_trace_multi_output_function(self): 1535 # An autograd.Function with two outputs. 1536 # It swaps inputs so we can check if shape 1537 # handling is correct in TorchScript. 1538 class Foo(torch.autograd.Function): 1539 @staticmethod 1540 def forward(ctx, x, y): 1541 return y, x 1542 1543 @staticmethod 1544 def backward(ctx, du, dv): 1545 return dv, du 1546 1547 class Bar(torch.nn.Module): 1548 def forward(self, x, y): 1549 x = x.relu() 1550 y = y.relu() 1551 z = Foo.apply(x, y) 1552 return z 1553 1554 x = torch.rand(3, 2, dtype=torch.double) 1555 y = torch.rand(1, 2, dtype=torch.double) 1556 1557 # Generate JIT IR. 1558 traced = torch.jit.trace(Bar(), (x, y)) 1559 print(traced.graph) 1560 1561 # Expected output schema of the custom autograd.Function. 1562 schema = ( 1563 "(Double(1, 2, strides=[2, 1], requires_grad=0, device=cpu), " 1564 "Double(3, 2, strides=[2, 1], requires_grad=0, device=cpu)) " 1565 "= ^Foo" 1566 ) 1567 1568 # See if expected schema exists. 1569 FileCheck().check(schema).run(traced.graph) 1570 1571 # Also examine if the graph is runnable and produces 1572 # the right result. 1573 u, v = traced(x, y) 1574 self.assertEqual(u, y) 1575 self.assertEqual(v, x) 1576 1577 def test_interpolate_trace(self): 1578 class test(nn.Module): 1579 def __init__(self) -> None: 1580 super().__init__() 1581 self.conv = nn.Conv2d(1, 32, kernel_size=3, padding=1) 1582 1583 def forward(self, x): 1584 y = self.conv(x) 1585 w = nn.functional.interpolate( 1586 y, mode="bilinear", align_corners=False, scale_factor=3 1587 ) 1588 return w 1589 1590 f = test() 1591 # no failure 1592 g = torch.jit.trace(f, (torch.zeros(1, 1, 28, 28),)) 1593 x = torch.zeros(1, 1, 14, 14) 1594 # constants not baked in 1595 self.assertEqual(g(x), f(x)) 1596 1597 @_tmp_donotuse_dont_inline_everything 1598 def test_trace_optional(self): 1599 @torch.jit.script 1600 def test(x: Optional[Tensor]): 1601 if x is None: 1602 return torch.zeros(1) 1603 else: 1604 return x 1605 1606 def test_none(): 1607 return test(None) 1608 1609 def test_tensor(): 1610 return test(torch.zeros(2)) 1611 1612 f_none = torch.jit.trace(test_none, ()) 1613 self.assertEqual(f_none(), torch.zeros(1)) 1614 1615 f_tensor = torch.jit.trace(test_tensor, ()) 1616 self.assertEqual(f_tensor(), torch.zeros(2)) 1617 1618 graph = f_tensor.graph 1619 FileCheck().check('name="test"').check_next("prim::CallFunction").run(graph) 1620 1621 def test_trace_nested_datatypes(self): 1622 @torch.jit.script 1623 def foo(x): 1624 return [[x + 1, x - 1], [x + 2, x - 2]] 1625 1626 def bar(x): 1627 list_stuff = foo(x) 1628 return list_stuff[0][0], list_stuff[1][1] 1629 1630 traced = torch.jit.trace(bar, torch.rand(3, 4)) 1631 x = torch.rand(5, 6) 1632 self.assertEqual(bar(x), traced(x)) 1633 1634 @_tmp_donotuse_dont_inline_everything 1635 def test_call_traced_fn_from_traced_module(self): 1636 @_trace(torch.rand(3, 4)) 1637 def traced_fn(x): 1638 return torch.neg(x) 1639 1640 class TracedModule(torch.nn.Module): 1641 def __init__(self) -> None: 1642 super().__init__() 1643 self.param = torch.nn.Parameter(torch.rand(4, 5)) 1644 1645 def forward(self, x): 1646 return traced_fn(torch.mm(x, self.param)) 1647 1648 tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) 1649 1650 # Note: neg op from the traced function should be properly inlined 1651 FileCheck().check("aten::mm").check('name="traced_fn"').check_next( 1652 "prim::CallFunction" 1653 ).run(str(tm.graph)) 1654 1655 @_tmp_donotuse_dont_inline_everything 1656 def test_call_traced_module_from_traced_module(self): 1657 class TracedModule1(torch.nn.Module): 1658 def __init__(self) -> None: 1659 super().__init__() 1660 self.param = torch.nn.Parameter(torch.rand(5, 7)) 1661 1662 def forward(self, x): 1663 return torch.mm(x, self.param) 1664 1665 class TracedModule(torch.nn.Module): 1666 def __init__(self) -> None: 1667 super().__init__() 1668 self.param = torch.nn.Parameter(torch.rand(4, 5)) 1669 self.mod = torch.jit.trace(TracedModule1(), torch.rand(3, 5)) 1670 1671 def forward(self, x): 1672 return self.mod(torch.mm(x, self.param)) + 1.0 1673 1674 tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) 1675 1676 FileCheck().check("aten::mm").check("prim::CallMethod").check_same( 1677 "forward" 1678 ).check("aten::add").run(str(tm.graph)) 1679 1680 def test_index_put_trace_with_view(self): 1681 @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(1, 1, 1, 4)) 1682 def test_index_put(target, indices, rhs): 1683 target[indices] = rhs 1684 return target 1685 1686 FileCheck().check("aten::view").check("index_put_").run( 1687 str(test_index_put.graph) 1688 ) 1689 1690 def test_index_put_trace_without_view(self): 1691 @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(4)) 1692 def test_index_put(target, indices, rhs): 1693 target[indices] = rhs 1694 return target 1695 1696 FileCheck().check_not("aten::view").check("index_put_").run( 1697 str(test_index_put.graph) 1698 ) 1699 1700 @suppress_warnings 1701 def test_trace_checker_dot_data(self): 1702 with self.assertRaisesRegex( 1703 torch.jit.TracingCheckError, 1704 r"Tensor-valued Constant nodes differed in value " r"across invocations", 1705 ): 1706 1707 @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)]) 1708 def foo(x): 1709 y = x.data 1710 return x + y 1711 1712 @suppress_warnings 1713 def test_trace_checker_control_flow(self): 1714 def foo(x): 1715 for _ in range(x.size(0)): 1716 x = torch.neg(x) 1717 return x 1718 1719 with self.assertRaisesRegex( 1720 torch.jit.TracingCheckError, r"Graphs differed across invocations!" 1721 ): 1722 torch.jit.trace(foo, torch.randn(3, 4), check_inputs=[torch.randn(4, 4)]) 1723 1724 @suppress_warnings 1725 def test_trace_checker_memoization(self): 1726 with self.assertRaisesRegex( 1727 torch.jit.TracingCheckError, r"Graphs differed across invocations!" 1728 ): 1729 1730 def foo(x): 1731 if not hasattr(foo, "cache"): 1732 foo.cache = torch.neg(x) 1733 return x + foo.cache 1734 1735 traced = torch.jit.trace( 1736 foo, torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)] 1737 ) 1738 1739 def test_trace_checker_slice_lhs(self): 1740 def foo(x): 1741 for i in range(3): 1742 x[i, :] = torch.zeros(4) 1743 return x 1744 1745 self.checkTrace(foo, (torch.rand(3, 4),), inputs_require_grads=False) 1746 1747 def test_trace_checker_inplace_on_view(self): 1748 def foo(x): 1749 x.view(-1).add_(-x.view(-1)) 1750 return x 1751 1752 with self.assertWarnsRegex( 1753 torch.jit.TracerWarning, 1754 "Output nr 1. of the traced function does not match the " 1755 "corresponding output of the Python function", 1756 ): 1757 torch.jit.trace( 1758 foo, 1759 torch.rand(3, 4), 1760 check_inputs=[torch.rand(5, 6)], 1761 _force_outplace=True, 1762 ) 1763 1764 def test_lhs_index_fails(self): 1765 def foo(x): 1766 x[0, 1] = 4 1767 return x 1768 1769 with self.assertWarnsRegex( 1770 torch.jit.TracerWarning, "cause the trace to be incorrect" 1771 ): 1772 torch.jit.trace(foo, torch.rand(3, 4), _force_outplace=True) 1773 1774 def test_lhs_index_trivial(self): 1775 def foo(y, x): 1776 y[...] = x 1777 return y 1778 1779 self.checkTrace( 1780 foo, (torch.rand(3, 4), torch.rand(4)), inputs_require_grads=False 1781 ) 1782 1783 def test_inplace_warn(self): 1784 def foo(x): 1785 x.view(-1).add_(-x.view(-1)) 1786 return x 1787 1788 with self.assertWarnsRegex( 1789 torch.jit.TracerWarning, "cause the trace to be incorrect" 1790 ): 1791 torch.jit.trace(foo, torch.rand(3, 4), _force_outplace=True) 1792 1793 @suppress_warnings 1794 def test_trace_checker_dropout_train(self): 1795 def foo(x): 1796 return torch.dropout(x, p=0.5, train=True) 1797 1798 with self.assertWarnsRegex( 1799 torch.jit.TracerWarning, 1800 "Output nr 1. of the traced function does not match the " 1801 "corresponding output of the Python function", 1802 ): 1803 torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]) 1804 1805 with self.assertWarnsRegex( 1806 torch.jit.TracerWarning, "Trace had nondeterministic nodes" 1807 ): 1808 torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]) 1809 1810 def test_trace_checker_dropout_notrain(self): 1811 input = torch.rand(3, 4) 1812 1813 @_trace(input) 1814 def foo(x): 1815 return torch.dropout(x, p=0.5, train=False) 1816 1817 self.assertEqual(foo(input), input) 1818 1819 def test_trace_contiguous(self): 1820 def foo(x): 1821 return x[:, :, ::2].contiguous().view(12) 1822 1823 x = torch.rand(2, 3, 4) 1824 traced = torch.jit.trace(foo, (x,)) 1825 y = traced(x) 1826 self.assertNotEqual(x.storage().data_ptr(), y.storage().data_ptr()) 1827 1828 # This tests the logic in THPVariable_contiguous. There is short-circuiting 1829 # code that prevents us from even getting to VariableType::contiguous, since 1830 # it is an optimization that prevents us from acquiring the GIL for touching 1831 # the device. We needed to add the tracing logic directly into the 1832 # THPVariable_contiguous function only for the path where we are skipping 1833 # dispatch into contiguous. We should see an aten::contiguous in this trace! 1834 def test_trace_contiguous_short_circuit(self): 1835 def foo(x): 1836 return x.contiguous() 1837 1838 x = torch.rand(2, 3, 4) 1839 traced = torch.jit.trace(foo, (x,)) 1840 FileCheck().check("aten::contiguous").run(str(traced.graph)) 1841 1842 def test_trace_inverse(self): 1843 def foo(x): 1844 return ~x 1845 1846 foo_traced = torch.jit.trace(foo, torch.zeros(3, 4, dtype=torch.uint8)) 1847 eg = torch.zeros(3, dtype=torch.uint8) 1848 self.assertEqual(foo_traced(eg), foo(eg)) 1849 1850 def test_trace_modulelist(self): 1851 class MySubmod(torch.nn.Module): 1852 def __init__(self) -> None: 1853 super().__init__() 1854 self.relu = torch.nn.ReLU() 1855 1856 def forward(self, x): 1857 return self.relu(x) 1858 1859 class MyMod(torch.nn.Module): 1860 def __init__(self) -> None: 1861 super().__init__() 1862 self.ml = torch.nn.ModuleList([MySubmod(), MySubmod()]) 1863 1864 def forward(self, x): 1865 for mod in self.ml: 1866 x = mod(x) 1867 return x 1868 1869 traced = torch.jit.trace(MyMod(), (torch.rand(3, 4),)) 1870 1871 def test_trace_fork_join_and_module(self): 1872 class MySubmod(torch.nn.Module): 1873 def __init__(self) -> None: 1874 super().__init__() 1875 self.relu = torch.nn.ReLU() 1876 1877 def forward(self, x): 1878 return self.relu(x), torch.neg(x) 1879 1880 class Mod(torch.nn.Module): 1881 def __init__(self) -> None: 1882 super().__init__() 1883 self.ml = torch.nn.ModuleList([MySubmod() for i in range(2)]) 1884 1885 def forward(self, x): 1886 futs = [] 1887 for i in range(2): 1888 futs.append(torch.jit._fork(self.ml[i], x)) 1889 1890 results = [] 1891 for i in range(2): 1892 results.append(torch.jit._wait(futs[i])[0]) 1893 1894 return torch.stack(results) 1895 1896 m = Mod() 1897 traced = torch.jit.trace(m, torch.rand(3, 4)) 1898 1899 def test_trace_invert_module_hierarchy(self): 1900 class MySubmod(torch.nn.Module): 1901 def __init__(self) -> None: 1902 super().__init__() 1903 self.relu = torch.nn.ReLU() 1904 1905 def forward(self, x): 1906 return self.relu(x), torch.neg(x) 1907 1908 class MyFunctionalMod(torch.nn.Module): 1909 def forward(self, x, submod): 1910 return submod(x) 1911 1912 class Mod(torch.nn.Module): 1913 def __init__(self) -> None: 1914 super().__init__() 1915 self.sm = MySubmod() 1916 self.fm = MyFunctionalMod() 1917 1918 def forward(self, x): 1919 return self.fm(x, self.sm) 1920 1921 torch.jit.trace(Mod(), (torch.rand(3, 4),)) 1922 1923 @skipIfCrossRef 1924 def test_trace_records_names(self): 1925 def foo(bar, baz): 1926 baz = bar + 3 1927 quick_brown_fox = torch.neg(baz) 1928 for _ in range(20): 1929 yeet = quick_brown_fox - 3.14 1930 return yeet 1931 1932 traced = torch.jit.trace(foo, (torch.rand(3, 3), torch.rand(3, 3))) 1933 graph_str = str(traced.graph) 1934 assert "bar" in graph_str 1935 assert "baz" in graph_str 1936 assert "quick_brown_fox" in graph_str 1937 1938 @skipIfTorchDynamo("Not a suitable test for TorchDynamo") 1939 def test_tracing_hooks(self): 1940 class Net(nn.Module): 1941 def forward(self, x): 1942 return x + x 1943 1944 def test_hook(is_post_hook, hook, fc): 1945 n = Net() 1946 if is_post_hook: 1947 n.register_forward_hook(hook) 1948 else: 1949 n.register_forward_pre_hook(hook) 1950 1951 module = torch.jit.trace(n, (torch.tensor(1.0),)) 1952 1953 eager_input = torch.tensor(1.0) 1954 eager_out = n(eager_input) 1955 1956 fc.run(module.forward.graph) 1957 input = torch.tensor(1.0) 1958 output = module(input) 1959 1960 self.assertEqual(input, eager_input) 1961 self.assertEqual(output, eager_out) 1962 1963 def hook_no_return(mod, input, output): 1964 input[0].add_(1) 1965 output.sub_(1) 1966 1967 fc = FileCheck().check("add(").check("add_(").check("sub_(") 1968 test_hook(True, hook_no_return, fc) 1969 1970 def hook_return(mod, input, output): 1971 input[0].add_(1) 1972 return output - 3 1973 1974 fc = FileCheck().check("add(").check("add_(").check("sub(") 1975 test_hook(True, hook_return, fc) 1976 1977 b = torch.tensor(3.0) 1978 1979 def captured_hook(mod, input, output): 1980 return output - b 1981 1982 fc = FileCheck().check("add(").check("sub(") 1983 test_hook(True, captured_hook, fc) 1984 1985 def pre_hook_no_ret(mod, input): 1986 input[0].add_(3) 1987 1988 fc = FileCheck().check("add_(").check("add(") 1989 test_hook(False, pre_hook_no_ret, fc) 1990 1991 def pre_hook_ret(mod, input): 1992 return input[0] - 4 1993 1994 fc = FileCheck().check("sub(").check("add(") 1995 test_hook(False, pre_hook_ret, fc) 1996 1997 def test_tracing_backward_hook_error(self): 1998 class Net(nn.Module): 1999 def forward(self, x): 2000 return x + x 2001 2002 n = Net() 2003 2004 def backward_hook(module, grad_input, grad_output): 2005 pass 2006 2007 n.register_backward_hook(backward_hook) 2008 with self.assertRaisesRegex(Exception, "backward hooks assigned"): 2009 torch.jit.trace(n, (torch.tensor(1.0),)) 2010 2011 def test_tracing_multiple_methods(self): 2012 class Net(nn.Module): 2013 def __init__(self) -> None: 2014 super().__init__() 2015 self.conv = nn.Conv2d(1, 1, 3) 2016 2017 def forward(self, x): 2018 return self.conv(x) 2019 2020 def weighted_kernel_sum(self, weight): 2021 return weight * self.conv.weight 2022 2023 example_weight = torch.rand(1, 1, 3, 3) 2024 example_forward_input = torch.rand(1, 1, 3, 3) 2025 inputs = { 2026 "forward": example_forward_input, 2027 "weighted_kernel_sum": example_weight, 2028 } 2029 n = Net() 2030 module = torch.jit.trace_module(n, inputs) 2031 2032 check_inputs = [] 2033 for i in range(2): 2034 check_weight = torch.rand(1, 1, 3, 3) 2035 check_forward_input = torch.rand(1, 1, 3, 3) 2036 check_inputs.append( 2037 {"forward": check_forward_input, "weighted_kernel_sum": check_weight} 2038 ) 2039 module = torch.jit.trace_module( 2040 n, inputs, check_trace=True, check_inputs=check_inputs 2041 ) 2042 self.assertTrue(module._c._has_method("forward")) 2043 self.assertTrue(module._c._has_method("weighted_kernel_sum")) 2044 2045 module = torch.jit.trace(n.forward, example_forward_input) 2046 module = torch.jit.trace( 2047 n.forward, 2048 example_forward_input, 2049 check_trace=True, 2050 check_inputs=[example_forward_input], 2051 ) 2052 with self.assertRaisesRegex( 2053 AttributeError, 2054 "trace doesn't support compiling individual module's functions", 2055 ): 2056 module = torch.jit.trace(n.weighted_kernel_sum, inputs) 2057 2058 def test_tensor_with_grad_as_constant(self): 2059 param = torch.randn(3).requires_grad_() 2060 x = torch.randn(3) 2061 2062 def f(x): 2063 return x + param 2064 2065 with self.assertRaisesRegex( 2066 RuntimeError, "Cannot insert a Tensor that requires grad as a constant" 2067 ): 2068 torch.jit.trace(f, x) 2069 2070 def test_non_tensor_tracing(self): 2071 def f(x): 2072 return x + param # noqa: F821 2073 2074 with self.assertRaisesRegex( 2075 RuntimeError, r"Type 'Tuple\[int\]' cannot be traced" 2076 ): 2077 torch.jit.trace(f, (1,)) 2078 2079 def test_trace_skip_none_submodule(self): 2080 class TestModule(torch.nn.Module): 2081 def __init__(self) -> None: 2082 super().__init__() 2083 self.submod = torch.nn.Linear(3, 4) 2084 self.submod = None 2085 2086 def forward(self, inputs): 2087 return inputs 2088 2089 m = TestModule() 2090 tm = torch.jit.trace(m, torch.tensor(1.0)) 2091 self.assertFalse(hasattr(tm, "submod")) 2092 2093 def test_trace_with_conditional_property(self): 2094 class Net(nn.Module): 2095 def __init__(self, attr=None): 2096 super().__init__() 2097 if attr is not None: 2098 self._attr = attr 2099 self.attr_name = "_attr" 2100 2101 @property 2102 def attr(self): 2103 return getattr(self, self.attr_name) 2104 2105 def forward(self, x): 2106 return x 2107 2108 x = torch.ones(1) 2109 torch.jit.trace(Net(), x) 2110 2111 def test_trace_func_argument_names_captured(self): 2112 def fn(first_arg: torch.Tensor, second_arg: torch.Tensor) -> torch.Tensor: 2113 return first_arg + second_arg 2114 2115 traced_fn = torch.jit.trace(fn, (torch.ones(1), torch.ones(1))) 2116 FileCheck().check("first_arg").check_next("second_arg").run( 2117 str(traced_fn.graph) 2118 ) 2119 2120 def test_trace_partial_func_argument_names_captured(self): 2121 def fn(first_arg: torch.Tensor, second_arg=1) -> torch.Tensor: 2122 return first_arg + second_arg 2123 2124 traced_fn = torch.jit.trace(fn, (torch.ones(1),)) 2125 FileCheck().check("first_arg").check_not("second_arg").run(str(traced_fn.graph)) 2126 2127 def test_trace_module_argument_names_captured(self): 2128 class TestModule(nn.Module): 2129 def __init__(self) -> None: 2130 super().__init__() 2131 self.conv = nn.Conv2d(1, 1, 3) 2132 2133 def forward(self, first_arg: torch.Tensor, second_arg: torch.Tensor): 2134 return self.conv(first_arg) + second_arg 2135 2136 m = TestModule() 2137 example_input = (torch.ones(1, 1, 3, 3), torch.ones(1, 1, 3, 3)) 2138 2139 # Explicitly tracing module's forward method 2140 traced_module_forward = torch.jit.trace(m.forward, example_input) 2141 FileCheck().check("first_arg").check_next("second_arg").run( 2142 str(traced_module_forward.graph) 2143 ) 2144 2145 # Tracing module's directly 2146 traced_module = torch.jit.trace(m, example_input) 2147 FileCheck().check("first_arg").check_next("second_arg").run( 2148 str(traced_module.graph) 2149 ) 2150 2151 def test_trace_checking_with_deprecated_name(self): 2152 class MyClass(torch.nn.Module): 2153 def __init__(self) -> None: 2154 super(MyClass, self).__init__() 2155 2156 def forward(self, x, y, **deprecated_arguments): 2157 if len(deprecated_arguments) > 0: 2158 raise RuntimeError( 2159 f"Got unexpected arguments: {deprecated_arguments}" 2160 ) 2161 return x + y 2162 2163 model = MyClass() 2164 m2 = torch.jit.trace(model, (torch.ones(1), torch.ones(1))) 2165 m3 = torch.jit.trace( 2166 model, 2167 example_kwarg_inputs={"x": torch.ones(1), "y": torch.ones(1)}, 2168 strict=False, 2169 ) 2170 2171 def test_trace_with_tuple_tensor(self): 2172 class MyClass(torch.nn.Module): 2173 def __init__(self) -> None: 2174 super(MyClass, self).__init__() 2175 2176 def forward(self, x, y): 2177 return x + y[0] + y[1] 2178 2179 model = MyClass() 2180 traced_model = torch.jit.trace( 2181 model, (torch.ones(1), (torch.ones(1), torch.ones(1))) 2182 ) 2183 input_dict = { 2184 "x": torch.tensor([2, 3]), 2185 "y": (torch.tensor([5, 6]), torch.tensor([7, 8])), 2186 } 2187 self.assertEqual(model(**input_dict), traced_model(**input_dict)) 2188 traced_model = torch.jit.trace( 2189 model, 2190 example_kwarg_inputs={ 2191 "x": torch.ones(1), 2192 "y": (torch.ones(1), torch.ones(1)), 2193 }, 2194 ) 2195 self.assertEqual(model(**input_dict), traced_model(**input_dict)) 2196 2197 def test_trace_no_duplicated_lifted_input_output(self): 2198 class Normalize(nn.Module): 2199 def __init__(self) -> None: 2200 super().__init__() 2201 self.norm = nn.GroupNorm(num_groups=32, num_channels=32) 2202 2203 def forward(self, x, y): 2204 if y is None: 2205 y = x 2206 else: 2207 y = self.norm(y) 2208 y = y * 2 2209 return y 2210 2211 class G(nn.Module): 2212 def __init__(self) -> None: 2213 super().__init__() 2214 self.norm = Normalize() 2215 2216 def forward(self, x): 2217 A = self.norm(x, None) 2218 B = F.relu(A) 2219 return A, B 2220 2221 class Net(nn.Module): 2222 def __init__(self) -> None: 2223 super().__init__() 2224 self.g = G() 2225 self.norm_1 = Normalize() 2226 2227 def forward(self, x): 2228 hs = self.g(x) 2229 A, B = hs 2230 h = self.norm_1(B, A) 2231 return h 2232 2233 net = Net() 2234 net = net.eval() 2235 x = torch.randn(1, 32, 16, 16) 2236 traced = torch.jit.trace(net, x) 2237 FileCheck().check_not("prim::TupleUnpack").run(str(traced.graph)) 2238 2239 2240@skipIfTorchDynamo("Not a suitable test for TorchDynamo") 2241class TestMixTracingScripting(JitTestCase): 2242 def test_trace_script(self): 2243 @torch.jit.script 2244 def func1(x: Tuple[Tensor, Tensor]) -> Tensor: 2245 return x[0] + x[1] 2246 2247 @torch.jit.script 2248 def func2(x: List[Tensor]) -> Tensor: 2249 return x[0] + x[1] 2250 2251 a = torch.randn(5) 2252 b = torch.randn(5) 2253 2254 self.checkTrace(func1, ((a, b),)) 2255 self.checkTrace(func2, ((a, b),)) 2256 2257 @torch.jit.script 2258 def func3( 2259 x: Tensor, method: str = "bilinear", align_corners: bool = True 2260 ) -> Tensor: 2261 hw = x.shape[2:4] 2262 return F.interpolate(x, hw, mode=method, align_corners=align_corners) 2263 2264 inp = torch.rand(1, 3, 6, 6) 2265 self.checkTrace(func3, (inp,)) 2266 2267 @torch.jit.script 2268 def func4(x: Tensor, a: List[Optional[str]]) -> Tensor: 2269 if len(a) == 2: 2270 return x + 2 2271 else: 2272 return x 2273 2274 def test_trace_mixed_by_script_with_dict_output(self): 2275 @torch.jit.script 2276 def return_dict(input: torch.Tensor) -> Dict[str, torch.Tensor]: 2277 return {"foo": input + 1} 2278 2279 class TraceModule(torch.nn.Module): 2280 def forward(self, input): 2281 dict = return_dict(input) 2282 return dict["foo"] + dict["foo"] 2283 2284 x = torch.ones(1) 2285 tm = torch.jit.trace(TraceModule(), x) 2286 self.assertEqual(tm(x), x + 1 + x + 1) 2287 2288 def test_trace_of_script(self): 2289 @torch.jit.script 2290 def foo(a, c): 2291 b = 0.0 2292 if bool(a == 0.0): 2293 b = 1.0 2294 return b + c 2295 2296 a = torch.ones(1, dtype=torch.float) 2297 2298 @_trace(torch.zeros(1, dtype=torch.float)) 2299 def use(b): 2300 return foo(b - 1.0, a) + 1.0 2301 2302 # test we propagated shapes through the function 2303 self.assertTrue("Dynamic" not in str(use.graph)) 2304 2305 self.assertEqual(3, use(torch.ones(1, dtype=torch.float))) 2306 self.assertEqual(2, use(torch.zeros(1, dtype=torch.float))) 2307 2308 def test_trace_with_size(self): 2309 @_trace(torch.zeros(1, 1)) 2310 def foo(x): 2311 return x + 1 2312 2313 @torch.jit.script 2314 def bar(x): 2315 y = int(foo(x)) 2316 if 1 == 1: 2317 y = 7 2318 return y + 1 2319 2320 self.assertEqual(8, bar(torch.ones(1, 1))) 2321 2322 def test_tracing_slicing(self): 2323 @_trace(torch.zeros(10)) 2324 def foo_trace(x): 2325 return x[-5:-3] 2326 2327 @torch.jit.script 2328 def foo_script(x): 2329 return x[-5:-3] 2330 2331 def foo(x): 2332 return x[-5:-3] 2333 2334 a = torch.arange(0, 8) 2335 b = torch.arange(0, 20) 2336 self.assertEqual(foo_trace(a), foo_script(a)) 2337 self.assertEqual(foo_trace(a), foo(a)) 2338 self.assertNotEqual(foo_trace(a), foo_trace(b)) 2339 2340 def test_tracing_indexing(self): 2341 @_trace(torch.zeros(10)) 2342 def foo_trace(x): 2343 return x[-2] 2344 2345 @torch.jit.script 2346 def foo_script(x): 2347 return x[-2] 2348 2349 def foo(x): 2350 return x[-2] 2351 2352 a = torch.arange(0, 8) 2353 b = torch.arange(0, 20) 2354 self.assertEqual(foo_script(a), foo_trace(a)) 2355 self.assertEqual(foo_trace(a), foo(a)) 2356 self.assertNotEqual(foo_trace(a), foo_trace(b)) 2357 2358 def test_trace_hierarchy(self): 2359 # Test that we preserve the module hierarchy for a ScriptModule 2360 # submodule during tracing 2361 2362 class AnotherScriptMod(torch.jit.ScriptModule): 2363 def __init__(self) -> None: 2364 super().__init__() 2365 self.param = torch.nn.Parameter(torch.rand(1, 2, 3)) 2366 2367 @torch.jit.script_method 2368 def bar(self): 2369 return torch.zeros(4, 5) 2370 2371 class SomeScriptMod(torch.jit.ScriptModule): 2372 def __init__(self) -> None: 2373 super().__init__() 2374 self.asm = AnotherScriptMod() 2375 2376 @torch.jit.script_method 2377 def foo(self): 2378 return torch.zeros(3, 4) 2379 2380 @torch.jit.script_method 2381 def bar(self): 2382 return torch.zeros(4, 3) 2383 2384 class TraceMe(torch.nn.Module): 2385 def __init__(self) -> None: 2386 super().__init__() 2387 self.ssm = SomeScriptMod() 2388 2389 def forward(self, x): 2390 return self.ssm.bar() + x 2391 2392 orig = TraceMe() 2393 traced = torch.jit.trace(orig, (torch.rand(4, 3),)) 2394 # for each of these checks, check that *BOTH* the underlying 2395 # _C.ScriptModule object has the expected method/param, as well as the 2396 # Python object that wraps it. 2397 self.assertTrue(traced.ssm._c._has_method("foo")) 2398 self.assertTrue(hasattr(traced.ssm, "foo")) 2399 2400 imported = self.getExportImportCopy(traced) 2401 2402 self.assertTrue(imported.ssm._c._has_method("foo")) 2403 self.assertTrue(hasattr(imported.ssm, "foo")) 2404 2405 self.assertTrue(imported.ssm.asm._c._has_method("bar")) 2406 self.assertTrue(hasattr(imported.ssm.asm, "bar")) 2407 2408 self.assertTrue(hasattr(imported.ssm.asm, "param")) 2409 2410 def test_trace_parameter(self): 2411 class Param(nn.Module): 2412 def __init__(self) -> None: 2413 super().__init__() 2414 self.register_parameter("bias", nn.Parameter(torch.empty(4, 4))) 2415 2416 def forward(self, x): 2417 return x 2418 2419 class M3(torch.jit.ScriptModule): 2420 def __init__(self, model): 2421 super().__init__() 2422 self.traced = torch.jit.trace(model, (torch.rand(3, 3))) 2423 2424 @torch.jit.script_method 2425 def forward(self, x): 2426 return self.traced(x) 2427 2428 class M2(nn.Module): 2429 def __init__(self, model): 2430 super().__init__() 2431 self.module = M3(model) 2432 2433 def forward(self, x): 2434 return self.module(x) 2435 2436 class M1(torch.jit.ScriptModule): 2437 def __init__(self, model): 2438 super().__init__() 2439 self.traced = torch.jit.trace(M2(model), (torch.rand(3, 3))) 2440 2441 @torch.jit.script_method 2442 def forward(self, x): 2443 return self.traced(x) 2444 2445 with torch.jit.optimized_execution(False): 2446 module = M1(Param()) 2447 f = io.BytesIO() 2448 torch.jit.save(module, f) 2449 2450 @_tmp_donotuse_dont_inline_everything 2451 def test_call_script_fn_from_traced_module(self): 2452 @torch.jit.script 2453 def scripted_fn(x): 2454 return torch.neg(x) 2455 2456 class TracedModule(torch.nn.Module): 2457 def __init__(self) -> None: 2458 super().__init__() 2459 self.param = torch.nn.Parameter(torch.rand(4, 5)) 2460 2461 def forward(self, x): 2462 return scripted_fn(torch.mm(x, self.param)) 2463 2464 tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) 2465 FileCheck().check("aten::mm").check('name="scripted_fn"').check( 2466 "prim::CallFunction" 2467 ).run(str(tm.graph)) 2468 2469 @_tmp_donotuse_dont_inline_everything 2470 def test_call_script_module_from_traced_module(self): 2471 class ScriptMod(torch.jit.ScriptModule): 2472 def __init__(self) -> None: 2473 super().__init__() 2474 self.param_foo = torch.nn.Parameter(torch.rand(5, 7)) 2475 2476 @torch.jit.script_method 2477 def forward(self, x): 2478 return torch.mm(x, self.param_foo) 2479 2480 class TracedModule(torch.nn.Module): 2481 def __init__(self) -> None: 2482 super().__init__() 2483 self.param = torch.nn.Parameter(torch.rand(4, 5)) 2484 self.mod = ScriptMod() 2485 2486 def forward(self, x): 2487 return self.mod(torch.mm(x, self.param)) + 1.0 2488 2489 tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) 2490 2491 FileCheck().check("aten::mm").check("prim::CallMethod").check_same( 2492 "forward" 2493 ).check("aten::add").run(str(tm.graph)) 2494 2495 @_tmp_donotuse_dont_inline_everything 2496 def test_call_traced_fn_from_script_fn(self): 2497 @_trace(torch.rand(3, 4)) 2498 def traced_fn(x): 2499 return torch.neg(x) 2500 2501 @torch.jit.script 2502 def script_fn(x): 2503 return traced_fn(x) + 1 2504 2505 FileCheck().check("prim::CallFunction").check("aten::add").run( 2506 str(script_fn.graph) 2507 ) 2508 2509 def test_call_traced_mod_from_script_fn(self): 2510 with self.assertRaisesRegex( 2511 RuntimeError, 2512 "Cannot call a ScriptModule that is not a submodule of the caller", 2513 ): 2514 2515 class TracedModule(torch.nn.Module): 2516 def forward(self, x): 2517 return torch.mm(x, torch.zeros(4, 3)) 2518 2519 tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) 2520 2521 @torch.jit.script 2522 def script_fn(x): 2523 return tm(x) + 1 2524 2525 @_tmp_donotuse_dont_inline_everything 2526 def test_call_tracing_fn_from_script_module(self): 2527 @_trace(torch.rand(3, 3)) 2528 def traced_fn(x): 2529 return torch.neg(x) 2530 2531 class ScriptMod(torch.jit.ScriptModule): 2532 def __init__(self) -> None: 2533 super().__init__() 2534 self.param = torch.nn.Parameter(torch.rand(4, 3)) 2535 2536 @torch.jit.script_method 2537 def forward(self, x): 2538 return traced_fn(torch.mm(x, self.param)) 2539 2540 sm = ScriptMod() 2541 FileCheck().check("aten::mm").check("prim::CallFunction").run( 2542 str(sm.forward.graph) 2543 ) 2544 2545 @_tmp_donotuse_dont_inline_everything 2546 def test_call_tracing_mod_from_script_module(self): 2547 class TracedMod(torch.nn.Module): 2548 def __init__(self) -> None: 2549 super().__init__() 2550 self.param = torch.nn.Parameter(torch.rand(3, 5)) 2551 2552 def forward(self, x): 2553 return torch.mm(x, self.param) 2554 2555 class ScriptMod(torch.jit.ScriptModule): 2556 def __init__(self) -> None: 2557 super().__init__() 2558 self.param = torch.nn.Parameter(torch.rand(4, 3)) 2559 self.tm = torch.jit.trace(TracedMod(), torch.rand(3, 3)) 2560 2561 @torch.jit.script_method 2562 def forward(self, x): 2563 return self.tm(torch.mm(x, self.param)) 2564 2565 sm = ScriptMod() 2566 FileCheck().check("aten::mm").check("prim::CallMethod").run(str(sm.graph)) 2567 2568 def test_script_inline_trace_multiple_args(self): 2569 class M(torch.nn.Module): 2570 def forward(self, input, input2): 2571 return input + input2 2572 2573 class M2(torch.jit.ScriptModule): 2574 def __init__(self) -> None: 2575 super().__init__() 2576 self.m = torch.jit.trace(M(), (torch.zeros(4, 3), torch.zeros(4, 3))) 2577 2578 @torch.jit.script_method 2579 def forward(self, inp): 2580 return self.m(inp, inp) 2581 2582 with torch.jit.optimized_execution(False): 2583 m2 = M2() 2584 m2(torch.zeros(4, 3)) 2585 2586 def test_trace_dict_mix_script(self): 2587 class testB(torch.nn.Module): 2588 def __init__(self) -> None: 2589 super().__init__() 2590 self.linear = torch.nn.Linear(2, 2) 2591 2592 def forward(self, feature_map: Dict[str, List[Tensor]]) -> Tensor: 2593 output = [] 2594 for j in feature_map.values(): 2595 output.append(self.linear(j[0])) 2596 2597 return torch.stack(output) 2598 2599 class testA(torch.nn.Module): 2600 def __init__(self) -> None: 2601 super().__init__() 2602 self.b = torch.jit.script(testB()) 2603 2604 def forward(self, input_map: Dict[str, List[Tensor]]) -> Tensor: 2605 feature_map = {} 2606 for i, j in input_map.items(): 2607 feature_map[i] = [j[0]] 2608 2609 return self.b(feature_map) 2610 2611 input_map = { 2612 "1": [torch.rand(2, 2), torch.rand(2, 2)], 2613 "3": [torch.rand(2, 2), torch.rand(2, 2)], 2614 } 2615 model = testA() 2616 traced_model = torch.jit.trace(model, input_map) 2617 new_input_map = { 2618 "1": [torch.rand(2, 2), torch.randn(2, 2)], 2619 "3": [torch.rand(2, 2), torch.rand(2, 2)], 2620 } 2621 self.assertEqual(model(new_input_map), traced_model(new_input_map)) 2622 2623 def test_trace_script_returning_complex_dict(self): 2624 """Tracing over a script function returning a dictionary should work. 2625 The dictionary can should be able to contain other containers (like a tuple) recursively. 2626 """ 2627 2628 class ReturnsDict(torch.nn.Module): 2629 def forward( 2630 self, 2631 id_score_list: Dict[ 2632 str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor] 2633 ], 2634 ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: 2635 # do some random operations and then return a dict of the same structure 2636 v = id_score_list["1000"] 2637 idx_keys = v[1] - 1500000 2638 weights = v[2] 2639 result = {"1000": (v[0], idx_keys, weights)} 2640 return result 2641 2642 class ChecksDict(torch.nn.Module): 2643 def forward( 2644 self, input: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] 2645 ): 2646 v = input["1000"] 2647 return v[1] + 1 2648 2649 class TestModule(torch.nn.Module): 2650 def __init__(self, checks_dict, returns_dict): 2651 super().__init__() 2652 self.checks_dict = checks_dict 2653 self.returns_dict = returns_dict 2654 2655 def forward( 2656 self, input: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] 2657 ): 2658 foo = self.returns_dict(input) 2659 return self.checks_dict(foo) 2660 2661 input1 = { 2662 "1000": ( 2663 torch.tensor([0]), 2664 torch.tensor([], dtype=torch.int64), 2665 torch.tensor([]), 2666 ) 2667 } 2668 2669 input2 = { 2670 "1000": ( 2671 torch.tensor([0]), 2672 torch.tensor([1500000, 1500004], dtype=torch.int64), 2673 torch.tensor([2.0, 3.0]), 2674 ) 2675 } 2676 2677 checks_dict = torch.jit.script(ChecksDict()) 2678 returns_dict = torch.jit.script(ReturnsDict()) 2679 eager_module = TestModule(checks_dict, returns_dict) 2680 traced_module = torch.jit.trace(eager_module, input1) 2681 self.assertEqual(traced_module(input1), eager_module(input1)) 2682 self.assertEqual(traced_module(input2), eager_module(input2)) 2683 2684 def test_trace_returning_dict_with_tensor_tuples(self): 2685 """Tracing over a module returning a dictionary whose values are tuples of tensors 2686 should work. 2687 """ 2688 2689 class ReturnsDict(torch.nn.Module): 2690 def forward( 2691 self, k: torch.Tensor, v: torch.Tensor 2692 ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]: 2693 x = 2 * k 2694 y = 3 * v 2695 result = {"imakey": (x, y)} 2696 return result 2697 2698 class ReturnsBadDict(torch.nn.Module): 2699 def forward( 2700 self, k: torch.Tensor, v: torch.Tensor 2701 ) -> Dict[str, Tuple[torch.Tensor, float]]: 2702 x = 2 * k 2703 result = {"imakey": (x, 1)} 2704 return result 2705 2706 mod = ReturnsDict() 2707 traced_module = torch.jit.trace( 2708 mod, [torch.ones(1), torch.ones(1)], strict=False 2709 ) 2710 out = traced_module(torch.ones(1), torch.ones(1)) 2711 expected = {"imakey": (torch.tensor([2.0]), torch.tensor([3.0]))} 2712 self.assertEqual(out, expected) 2713 2714 with self.assertRaisesRegex( 2715 RuntimeError, "cannot be understood by the tracer, only outputs matching" 2716 ): 2717 mod = ReturnsBadDict() 2718 traced_module = torch.jit.trace( 2719 mod, [torch.ones(1), torch.ones(1)], strict=False 2720 ) 2721 2722 def test_trace_linear(self): 2723 m = torch.nn.Linear(20, 20) 2724 inp = torch.rand([20, 20]) 2725 self.checkTrace(m, (inp,)) 2726 g = torch.jit.trace(m, (inp,)).graph 2727 FileCheck().check("aten::linear").run(g) 2728 2729 def test_traced_module_implements_interface(self): 2730 @torch.jit.interface 2731 class TestModuleInterface(nn.Module): 2732 def forward( 2733 self, first_arg: torch.Tensor, second_arg: torch.Tensor 2734 ) -> torch.Tensor: 2735 pass 2736 2737 make_global(TestModuleInterface) 2738 2739 class TestModule(nn.Module): 2740 def __init__(self) -> None: 2741 super().__init__() 2742 self.conv = nn.Conv2d(1, 1, 3) 2743 2744 def forward( 2745 self, first_arg: torch.Tensor, second_arg: torch.Tensor 2746 ) -> torch.Tensor: 2747 return self.conv(first_arg) + second_arg 2748 2749 def fn_takes_interface(x: TestModuleInterface): 2750 ones = torch.ones(1, 1, 3, 3) 2751 return x.forward(ones, ones) 2752 2753 scripted_test_module = torch.jit.script(TestModule()) 2754 self.checkScript(fn_takes_interface, (scripted_test_module,)) 2755 2756 def test_traced_module_contains_scripted_interface_types(self): 2757 class LeafModule(torch.nn.Module): 2758 def __init__(self) -> None: 2759 super().__init__() 2760 self.weight = torch.nn.Parameter(torch.rand(19)) 2761 2762 def forward(self, input: torch.Tensor): 2763 return input + self.weight 2764 2765 class LowerModuleImpl(torch.nn.Module): 2766 def __init__(self) -> None: 2767 super().__init__() 2768 self.leaf = LeafModule() 2769 2770 def forward(self, input: torch.Tensor) -> torch.Tensor: 2771 return self.leaf(input) 2772 2773 @torch.jit.interface 2774 class LowerModuleInterface(torch.nn.Module): 2775 def forward(self, input: torch.Tensor) -> torch.Tensor: 2776 pass 2777 2778 class MiddleModule(torch.nn.Module): 2779 lower: LowerModuleInterface 2780 2781 def __init__(self, feature_processor_modules=None): 2782 super().__init__() 2783 self.lower = LowerModuleImpl() 2784 2785 def forward(self, input): 2786 return self.lower(input) 2787 2788 class WrapperModule(torch.nn.Module): 2789 def __init__(self, m): 2790 super().__init__() 2791 self.middle = m 2792 2793 def forward(self, input): 2794 return self.middle(input) 2795 2796 class TopModule(torch.nn.Module): 2797 def __init__(self) -> None: 2798 super().__init__() 2799 m = MiddleModule() 2800 m = torch.jit.script(m) 2801 self.sub1 = m 2802 self.sub2 = WrapperModule(m) 2803 2804 def forward(self, input: torch.Tensor): 2805 return self.sub1(input) + self.sub2(input) 2806 2807 top = TopModule() 2808 top_example_input = torch.ones(1) 2809 torch.jit.trace(top, top_example_input) 2810 2811 def test_jit_trace_callfunction_return_shapes(self): 2812 # a torch.jit.script function gets inserted as a CallFunction node 2813 @torch.jit.script 2814 def inner_fn(x): 2815 return torch.cat((x, x)) 2816 2817 def outer_fn(x, y): 2818 return inner_fn(x + y).relu() 2819 2820 x, y = [torch.rand((2, 2), dtype=torch.float) for _ in range(2)] 2821 fn_t = torch.jit.trace(outer_fn, (x, y)) 2822 2823 # expect that the CallFunction node return type has shape information on it. 2824 FileCheck().check("Float").check("4, 2").check("CallFunction").run(fn_t.graph) 2825 for n in fn_t.graph.nodes(): 2826 if n.kind() == "prim::CallFunction": 2827 self.assertTrue(n.output().isCompleteTensor()) 2828