1# Owner(s): ["oncall: jit"] 2 3import operator 4import unittest 5from textwrap import dedent 6from typing import Any, List 7 8import torch 9from torch import nn, Tensor 10from torch.testing import FileCheck 11from torch.testing._internal.common_methods_invocations import sample_inputs_cat_concat 12from torch.testing._internal.common_utils import make_tensor 13from torch.testing._internal.jit_utils import execWrapper, JitTestCase 14 15 16if __name__ == "__main__": 17 raise RuntimeError( 18 "This test file is not meant to be run directly, use:\n\n" 19 "\tpython test/test_jit.py TESTNAME\n\n" 20 "instead." 21 ) 22 23 24# XXX: still in prototype 25class TestSymbolicShapeAnalysis(JitTestCase): 26 def setUp(self): 27 super(JitTestCase, self).setUp() 28 self.prev_symbolic_shapes_test_enabled = ( 29 torch._C._jit_symbolic_shapes_test_mode_enabled() 30 ) 31 torch._C._jit_set_symbolic_shapes_test_mode(True) 32 33 def tearDown(self): 34 torch._C._jit_set_symbolic_shapes_test_mode( 35 self.prev_symbolic_shapes_test_enabled 36 ) 37 38 def test_shape_analysis(self): 39 @torch.jit.script 40 def foo(x, y): 41 return x * y 42 43 inputs = list(foo.graph.inputs()) 44 45 def prop_shapes_on_graph(inp0, inp1): 46 inputs[0].setType(inputs[0].type().with_sizes(inp0)) 47 inputs[1].setType(inputs[1].type().with_sizes(inp1)) 48 torch._C._jit_pass_propagate_shapes_on_graph(foo.graph) 49 50 prop_shapes_on_graph([1, 6, 5], [1, 7, 1, 5]) 51 FileCheck().check("1, 7, 6, 5").run(foo.graph) 52 53 # None implicitly creates a new symbolic symbol 54 prop_shapes_on_graph([None, None], [None, None, None]) 55 output_shape = foo.graph.findNode("aten::mul").output().type().symbolic_sizes() 56 inp0_shape = inputs[0].type().symbolic_sizes() 57 inp1_shape = inputs[1].type().symbolic_sizes() 58 59 # output shape dim 0 should be taken from the second inp dim0 60 # other two dims we cannot infer and are given a new symbolic shape 61 self.assertEqual(output_shape[0], inp1_shape[0]) 62 self.assertFalse(output_shape[1] in inp0_shape + inp1_shape) 63 self.assertFalse(output_shape[2] in inp0_shape + inp1_shape) 64 65 # XXX: symbolic shapes are represented with an increasing counter of unique 66 # values, use `_new_symbolic_shape_symbol` api instead of specifying negative 67 # dimensions directly so there is no chance of collision between manual number 68 # and current counter value. 69 sym1 = torch._C._new_symbolic_shape_symbol() 70 sym2 = torch._C._new_symbolic_shape_symbol() 71 sym3 = torch._C._new_symbolic_shape_symbol() 72 prop_shapes_on_graph([sym1, 1, sym3], [1, sym2, sym3]) 73 output_shape = foo.graph.findNode("aten::mul").output().type().symbolic_sizes() 74 self.assertEqual(output_shape[0], sym1) 75 self.assertEqual(output_shape[1], sym2) 76 self.assertEqual(output_shape[2], sym3) 77 78 def test_shared_shape_graph(self): 79 @torch.jit.script 80 def foo(x, y): 81 return x * y, x / y 82 83 mul_node = foo.graph.findNode("aten::mul") 84 div_node = foo.graph.findNode("aten::div") 85 86 mul_graph = torch._C._jit_shape_compute_graph_for_node(mul_node) 87 div_graph = torch._C._jit_shape_compute_graph_for_node(div_node) 88 self.assertIsNotNone(mul_graph) 89 self.assertIs(mul_graph, div_graph) 90 91 def test_write(self): 92 @torch.jit.script 93 def foo(a, b): 94 return a * b 95 96 # broadcast appends cant be removed, so we bail on propagation 97 torch._C._jit_pass_propagate_shapes_on_graph(foo.graph) 98 FileCheck().check("Tensor = aten::mul").run(foo.graph) 99 100 @torch.jit.script 101 def foo(y): 102 x = [1, 2, 3, 4] 103 x[0] = 5 104 return y.view(x) 105 106 torch._C._jit_pass_propagate_shapes_on_graph(foo.graph) 107 FileCheck().check("Tensor = aten::view").run(foo.graph) 108 109 def test_if_propagation(self): 110 @torch.jit.script 111 def foo(i: int, z): 112 x = torch.ones([2, 3, 4, 5]) 113 y = z.view([z.size(i), 3, 2, z.size(i)]) 114 if i == 4: 115 return x 116 else: 117 return y 118 119 torch._C._jit_pass_constant_propagation(foo.graph) 120 torch._C._jit_pass_propagate_shapes_on_graph(foo.graph) 121 view = foo.graph.findNode("aten::view") 122 123 def neg_to_one(li): 124 return [elem if elem >= 0 else -1 for elem in li] 125 126 self.assertEqual( 127 neg_to_one(view.output().type().symbolic_sizes()), [-1, 3, 2, -1] 128 ) 129 if_out = next(foo.graph.findNode("prim::If").outputs()) 130 self.assertEqual(neg_to_one(if_out.type().symbolic_sizes()), [-1, 3, -1, -1]) 131 132 def test_unary_shape_functions(self): 133 unary_ops = [ 134 torch.nn.functional.hardtanh, 135 ] 136 for fn in unary_ops: 137 t = torch.jit.trace(fn, (torch.rand([4, 4]))) 138 ten_input = next(t.graph.inputs()) 139 ten_input.setType(ten_input.type().with_sizes([2, 2])) 140 torch._C._jit_pass_propagate_shapes_on_graph(t.graph) 141 self.assertEqual(next(t.graph.outputs()).type().symbolic_sizes(), [2, 2]) 142 143 def test_unary_shape_fns_inplace(self): 144 def mul_inplace(x: torch.Tensor): 145 y = x.mul_(2) 146 return y 147 148 unary_ops = [mul_inplace] 149 for fn in unary_ops: 150 # t = torch.jit.trace(fn, torch.rand([4, 4])) # For some reason tracing is erroring out. 151 t = torch.jit.script(fn) 152 ten_input = next(t.graph.inputs()) 153 ten_input.setType(ten_input.type().with_sizes([2, 2])) 154 torch._C._jit_pass_propagate_shapes_on_graph(t.graph) 155 self.assertEqual(next(t.graph.outputs()).type().symbolic_sizes(), [2, 2]) 156 157 def test_binary_shape_functions(self): 158 binary_ops = [ 159 operator.__mul__, 160 operator.__truediv__, 161 operator.__gt__, 162 operator.__add__, 163 ] 164 165 for fn in binary_ops: 166 size_1 = [1, 4, 8] 167 size_2 = [4, 1, 8] 168 t = torch.jit.trace(fn, (torch.rand([4]), torch.rand([4]))) 169 inputs = list(t.graph.inputs()) 170 inputs[0].setType(inputs[0].type().with_sizes(size_1)) 171 inputs[1].setType(inputs[1].type().with_sizes(size_2)) 172 torch._C._jit_pass_propagate_shapes_on_graph(t.graph) 173 self.assertEqual(next(t.graph.outputs()).type().symbolic_sizes(), [4, 4, 8]) 174 175 def test_binary_shape_fns_inplace(self): 176 def div_inplace_tensor(x: torch.Tensor, y: torch.Tensor): 177 z = x.div_(y) 178 return z 179 180 def add_inplace_tensor(x: torch.Tensor, y: torch.Tensor): 181 z = x.add_(y) 182 return z 183 184 binary_ops = [ 185 div_inplace_tensor, 186 add_inplace_tensor, 187 ] 188 189 for fn in binary_ops: 190 size_1 = [4, 4, 8] # x (can't broadcast because it's an inplace op) 191 t = torch.jit.script(fn) 192 inputs = list(t.graph.inputs()) 193 inputs[0].setType(inputs[0].type().with_sizes(size_1)) 194 # Intentionally not populate the type of inputs[1] 195 torch._C._jit_pass_propagate_shapes_on_graph(t.graph) 196 self.assertEqual(next(t.graph.outputs()).type().symbolic_sizes(), [4, 4, 8]) 197 198 def test_size_and_sizes(self): 199 @torch.jit.script 200 def foo(x, y): 201 return x.view(y.size(0), 8, y.size(-1)) 202 203 @torch.jit.script 204 def foo2(x, y): 205 return x.view(y.size()) 206 207 for graph in [foo.graph, foo2.graph]: 208 inputs = list(graph.inputs()) 209 sym1 = torch._C._new_symbolic_shape_symbol() 210 211 inputs[1].setType(inputs[1].type().with_sizes([5, 8, sym1])) 212 torch._C._jit_pass_propagate_shapes_on_graph(graph) 213 self.assertEqual( 214 next(graph.outputs()).type().symbolic_sizes(), [5, 8, sym1] 215 ) 216 217 def test_adaptive_avg_pool2d(self): 218 inps = [ 219 [(1, 64, 8, 9), (5, 7)], 220 [(1, 64, 10, 9), (7)], 221 [(1, 64, 10, 9), (5, None)], 222 [(1, 8, 4, 3), (None, None)], 223 [(1, 8, 4, 3), (None, 5)], 224 ] 225 226 for inp in inps: 227 t = torch.randn(*inp[0]) 228 out_size = torch.nn.functional.adaptive_avg_pool2d(t, inp[1]).size() 229 230 def foo(x): 231 return torch.nn.functional.adaptive_avg_pool2d(x, inp[1]) 232 233 fn = torch.jit.trace(foo, (t,)) 234 torch._C._jit_erase_non_input_shape_information(fn.graph) 235 torch._C._jit_pass_peephole(fn.graph) 236 torch._C._jit_pass_constant_propagation(fn.graph) 237 self.checkShapeAnalysis(out_size, fn.graph, assert_propagation=True) 238 239 def test_conv_deconv(self): 240 for ( 241 inp_shape, 242 weight_shape, 243 bias, 244 stride, 245 padding, 246 output_padding, 247 dilation, 248 groups, 249 mod, 250 ) in [ 251 ([32, 6, 10], [16, 3, 3], None, 2, 2, 1, 1, 2, torch.nn.functional.conv1d), 252 ( 253 [32, 16, 10], 254 [16, 3, 3], 255 None, 256 2, 257 2, 258 1, 259 1, 260 2, 261 torch.nn.functional.conv_transpose1d, 262 ), 263 ( 264 [1, 32, 5, 10], 265 [30, 16, 3, 3], 266 None, 267 [2, 2], 268 [0, 0], 269 0, 270 1, 271 2, 272 torch.nn.functional.conv2d, 273 ), 274 ( 275 [1, 30, 5, 10], 276 [30, 16, 3, 3], 277 None, 278 [2, 2], 279 [0, 0], 280 0, 281 1, 282 2, 283 torch.nn.functional.conv_transpose2d, 284 ), 285 ( 286 [3, 14, 10, 66, 55], 287 [2, 7, 7, 4, 4], 288 None, 289 1, 290 1, 291 2, 292 1, 293 2, 294 torch.nn.functional.conv3d, 295 ), 296 ( 297 [3, 2, 10, 66, 55], 298 [2, 7, 7, 4, 4], 299 None, 300 1, 301 1, 302 0, 303 1, 304 2, 305 torch.nn.functional.conv_transpose3d, 306 ), 307 ]: 308 inp = torch.rand(inp_shape) 309 weight = torch.rand(weight_shape) 310 if mod in [ 311 torch.nn.functional.conv1d, 312 torch.nn.functional.conv2d, 313 torch.nn.functional.conv3d, 314 ]: 315 res = mod(inp, weight, bias, stride, padding, dilation, groups).size() 316 else: 317 res = mod( 318 inp, weight, bias, stride, padding, output_padding, dilation, groups 319 ).size() 320 321 def foo(inp, weight): 322 if mod in [ 323 torch.nn.functional.conv1d, 324 torch.nn.functional.conv2d, 325 torch.nn.functional.conv3d, 326 ]: 327 return mod(inp, weight, bias, stride, padding, dilation, groups) 328 else: 329 return mod( 330 inp, 331 weight, 332 bias, 333 stride, 334 padding, 335 output_padding, 336 dilation, 337 groups, 338 ) 339 340 fn = torch.jit.trace(foo, (inp, weight)) 341 torch._C._jit_erase_non_input_shape_information(fn.graph) 342 torch._C._jit_pass_peephole(fn.graph) 343 torch._C._jit_pass_constant_propagation(fn.graph) 344 self.checkShapeAnalysis(res, fn.graph, assert_propagation=True) 345 346 def test_arange_shape(self): 347 # no opinfo for tensor constructors 348 inps = [ 349 (10,), 350 (10, 10), 351 (0, 10), 352 (0, 1000), 353 (1, -1, -1), 354 (1, 0, -1), 355 (1, 2, 1), 356 (0.6, 0.89, 0.1), 357 (1, 10, 0.3), 358 (1, 10, 4), 359 (0.6, 0.7, 0.8), 360 (1, 10, 0.3), 361 # (True,), TODO: https://github.com/pytorch/pytorch/issues/63405 362 # (False,), TODO: https://github.com/pytorch/pytorch/issues/63405 363 (0, 5), 364 (0, 5, 2), 365 (0, 5 + 1e-6), 366 (0, 5 - 1e-6), 367 (10, -1 + 1e-6, -1), 368 (10, -1, -1), 369 (10, -1 - 1e-6, -1), 370 ] 371 372 for inp in inps: 373 funcs_template = dedent( 374 """ 375 def func(): 376 return torch.arange({args}) 377 """ 378 ) 379 380 inp_s = str(inp)[1:-1] # remove tuple parens 381 funcs_str = funcs_template.format(args=inp_s) 382 scope = {} 383 execWrapper(funcs_str, globals(), scope) 384 cu = torch.jit.CompilationUnit(funcs_str) 385 self.checkShapeAnalysis( 386 list(cu.func().size()), 387 cu.func.graph, 388 assert_propagation=True, 389 constant_prop=False, 390 ) 391 392 def test_shape_embedding_bag(self): 393 # TODO: merge into opinfos, having difficulties there 394 with torch.no_grad(): 395 396 def make_arg(shape, low=None, high=None): 397 return make_tensor( 398 shape, 399 device="cpu", 400 dtype=torch.int64, 401 low=low, 402 high=high, 403 requires_grad=False, 404 ) 405 406 nn_inps = ( 407 ( 408 make_arg((40,), 0, 9), 409 torch.nn.Embedding(20, embedding_dim=64, max_norm=1.0), 410 ), 411 (make_arg((2, 4), 0, 9), torch.nn.Embedding(10, 20, sparse=True)), 412 (make_arg((0,)), torch.nn.Embedding(0, 0, sparse=True)), 413 (make_arg((2, 4), 0, 9), torch.nn.Embedding(10, 0, sparse=True)), 414 (make_arg((4,), 0, 21), torch.nn.Embedding(22, 5, max_norm=1.0)), 415 ( 416 make_arg((2,), 0, 1), 417 torch.nn.Embedding.from_pretrained( 418 torch.arange(6.0).view(2, 3), 419 max_norm=2.0, 420 norm_type=0.5, 421 scale_grad_by_freq=False, 422 sparse=True, 423 ), 424 ), 425 ) 426 427 for inp, module in nn_inps: 428 kwargs = { 429 "weight": module.weight.detach(), 430 "padding_idx": module.padding_idx, 431 "max_norm": module.max_norm, 432 "norm_type": module.norm_type, 433 "scale_grad_by_freq": module.scale_grad_by_freq, 434 "sparse": module.sparse, 435 } 436 437 out_size = torch.nn.functional.embedding(inp, **kwargs).size() 438 439 def foo(x): 440 return torch.nn.functional.embedding(inp, **kwargs) 441 442 fn = torch.jit.trace(foo, (inp.detach(),), check_trace=False) 443 444 self.checkShapeAnalysis( 445 out_size, fn.graph, assert_propagation=True, constant_prop=False 446 ) 447 448 def test_shape_concat(self): 449 # TODO: unify with opinfo tests, traces of lists dont preserve sizes in IR 450 sample_inputs = sample_inputs_cat_concat(None, "cpu", torch.float, False) 451 452 class CatMod(nn.Module): 453 __constants__ = ["dim"] 454 455 def __init__(self, dim=0): 456 super().__init__() 457 self.dim = dim 458 459 def forward(self, x, y): 460 return torch.cat([x, y], dim=self.dim) 461 462 for inp in sample_inputs: 463 mod = torch.jit.script(CatMod(**inp.kwargs).eval()) 464 465 args = inp.input 466 467 # This test is hard-coded only to work with two sample inputs 468 # but the OpInfo may have more/less 469 if len(args) != 2: 470 continue 471 472 out_size = mod(*args).size() 473 inps = list(mod.graph.inputs()) 474 inps[1].setType(inps[1].type().with_sizes(args[0].size())) 475 inps[2].setType(inps[2].type().with_sizes(args[1].size())) 476 self.checkShapeAnalysis(out_size, mod.graph, assert_propagation=True) 477 478 def assert_shape_equal_scripted(self, script_fn, given_ins): 479 expected_res = script_fn(*given_ins) 480 g = script_fn.graph 481 graph_ins = list(g.inputs()) 482 self.assertEqual(len(given_ins), len(graph_ins)) 483 for inp, graph_in in zip(given_ins, graph_ins): 484 graph_in.setType(graph_in.type().with_sizes(inp.size())) 485 486 out_sizes = [out.size() for out in expected_res] 487 self.checkShapeAnalysis(out_sizes, g, assert_propagation=True) 488 489 def test_convolution_backward(self): 490 # No opinfos for ops that are not part of the Python API 491 # Also, as the return shapes are the input, weight, and bias shape, there is no point 492 # in a really complicated test 493 494 input = torch.randn( 495 (16, 16, 8, 8), dtype=torch.float32, device="cpu", requires_grad=True 496 ) 497 weight = torch.randn( 498 (8, 4, 3, 3), dtype=torch.float32, device="cpu", requires_grad=True 499 ) 500 out_grad = torch.randn((16, 8, 8, 8), dtype=torch.float32, device="cpu") 501 502 @torch.jit.script 503 def conv_bwd(input, weight, grad): 504 bias_sizes = [ 505 8, 506 ] 507 args = ([1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, True]) 508 return torch.ops.aten.convolution_backward( 509 grad, input, weight, bias_sizes, *args 510 ) 511 512 self.assert_shape_equal_scripted(conv_bwd, (input, weight, out_grad)) 513 514 @torch.jit.script 515 def conv_bwd_2(input, weight, grad): 516 bias_sizes = None 517 args = ([1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, True]) 518 return torch.ops.aten.convolution_backward( 519 grad, input, weight, bias_sizes, *args 520 ) 521 522 self.assert_shape_equal_scripted(conv_bwd_2, (input, weight, out_grad)) 523 524 def test_returning_input_symbolic_shapes(self): 525 mm = torch.jit.freeze(torch.jit.script(nn.Conv2d(16, 33, 3, stride=2).eval())) 526 inps = list(mm.graph.inputs()) 527 inps[1].setType(inps[1].type().with_sizes([None, None, None, None])) 528 shape_compute_graph = ( 529 torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mm.graph) 530 ) 531 g = shape_compute_graph.partial_eval_shape_graph() 532 # to make into a jit function cant have multiple outputs 533 g.makeMultiOutputIntoTuple() 534 func = torch._C._create_function_from_graph("partial_eval_graph", g) 535 out = func([20, 16, 5, 10]) 536 # first four outputs should be unknown symbolic shapes from input 537 self.assertEqual(out[0:4], [20, 16, 5, 10]) 538 # last two are two new symbolic dims - height and width 539 self.assertEqual(out[4:], list(mm(torch.rand([20, 16, 5, 10])).size()[2:])) 540 541 def test_partial_eval_graph_conv(self): 542 mm = torch.jit.freeze(torch.jit.script(nn.Conv2d(16, 33, 3, stride=2).eval())) 543 shape_compute_graph = ( 544 torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mm.graph) 545 ) 546 output_sizes = ( 547 mm.graph.findNode("aten::conv2d").output().type().symbolic_sizes() 548 ) 549 # calculating 0, 2 and 3 index 550 for i in [0, 2, 3]: 551 self.assertTrue(output_sizes[i] < 0) 552 self.assertTrue(output_sizes[1] >= 0) 553 g = shape_compute_graph.partial_eval_shape_graph() 554 # to make into a jit function cant have multiple outputs 555 g.makeMultiOutputIntoTuple() 556 func = torch._C._create_function_from_graph("partial_eval_graph", g) 557 inp = torch.randn(20, 16, 5, 10) 558 output = func([20, 16, 5, 10]) 559 output_eager = list(mm(inp).size()) 560 for o, oe in zip(output, output_eager[0:1] + output_eager[2:]): 561 self.assertEqual(o, oe) 562 563 def checkSymShapeCompute( 564 self, shape_compute_graph, nodes, node_output_sizes, shape_inputs 565 ): 566 g = shape_compute_graph.partial_eval_shape_graph() 567 self.assertTrue(len(list(g.inputs())) == len(shape_inputs)) 568 output_sym_map = shape_compute_graph.graph_output_to_symbolic_shape_dim() 569 # map from sym shape -> index 570 sym_shape_to_index = {} 571 for index, output in enumerate(g.outputs()): 572 sym_shape_to_index[output_sym_map[output]] = index 573 574 g.makeMultiOutputIntoTuple() 575 func = torch._C._create_function_from_graph("partial_eval_graph", g) 576 sym_outputs = func(*shape_inputs) 577 578 for node, output_shape in zip(nodes, node_output_sizes): 579 output_type_sizes = node.output().type().symbolic_sizes() 580 for i, sym_shape in enumerate(output_type_sizes): 581 if sym_shape >= 0: 582 self.assertEqual(sym_shape, output_shape[i]) 583 else: 584 sym_shape_index = sym_shape_to_index[sym_shape] 585 self.assertEqual(sym_outputs[sym_shape_index], output_shape[i]) 586 587 def test_partial_eval_stitching(self): 588 conv1 = torch.nn.Conv2d( 589 3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False 590 ) 591 max_pool = torch.nn.MaxPool2d( 592 kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False 593 ) 594 conv2 = nn.Conv2d( 595 64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False 596 ) 597 598 mod = torch.jit.freeze( 599 torch.jit.script(nn.Sequential(conv1, max_pool, conv2).eval()) 600 ) 601 602 conv1_output = conv1(torch.rand(1, 3, 224, 224)) 603 max_pool_output = max_pool(conv1_output) 604 conv2_output = conv2(max_pool_output) 605 606 shape_compute_graph = ( 607 torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph) 608 ) 609 nodes = [mod.graph.findNode("aten::max_pool2d")] + list( 610 mod.graph.findAllNodes("aten::conv2d") 611 ) 612 output_shapes = [ 613 max_pool_output.size(), 614 conv1_output.size(), 615 conv2_output.size(), 616 ] 617 self.checkSymShapeCompute( 618 shape_compute_graph, nodes, output_shapes, ([1, 3, 224, 224],) 619 ) 620 621 def test_refinement_through_graph_stitching(self): 622 class TwoConvs(torch.nn.Module): 623 def __init__(self) -> None: 624 super().__init__() 625 self.conv1 = torch.nn.Conv2d( 626 3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False 627 ) 628 self.conv2 = torch.nn.Conv2d( 629 3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False 630 ) 631 632 def forward(self, x): 633 a = self.conv1(x) 634 b = self.conv2(x) 635 return a + b 636 637 mod = torch.jit.freeze(torch.jit.script(TwoConvs()).eval()) 638 inp_tensor = list(mod.graph.inputs())[1] 639 inp_tensor.setType(inp_tensor.type().with_sizes([None, None, None, None])) 640 torch._C._jit_pass_propagate_shapes_on_graph(mod.graph) 641 outs = list(next(mod.graph.outputs()).node().inputs()) 642 out1 = outs[0].type().symbolic_sizes() 643 out2 = outs[1].type().symbolic_sizes() 644 self.assertTrue(out1[2] != out2[2]) 645 self.assertTrue(out1[3] != out2[3]) 646 # by joining partial eval graphs of both convs we are able to recognize the output shapes 647 # are equivalent 648 torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph) 649 out1 = outs[0].type().symbolic_sizes() 650 out2 = outs[1].type().symbolic_sizes() 651 self.assertEqual(out1, out2) 652 653 def test_stitching_multi_output(self): 654 max_pool = torch.nn.MaxPool2d( 655 kernel_size=3, 656 stride=2, 657 padding=1, 658 dilation=1, 659 ceil_mode=False, 660 return_indices=True, 661 ) 662 tensor = torch.rand(1, 3, 224, 224) 663 mod = torch.jit.trace(max_pool, (tensor,)) 664 mod = torch.jit.freeze(mod.eval()) 665 inp = list(mod.graph.inputs())[1] 666 inp.setType(inp.type().with_sizes([None, None, None, None])) 667 output_tensor = list(mod(tensor)[0].size()) 668 self.run_pass("lower_all_tuples", mod.graph) 669 shape_compute_graph = ( 670 torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph) 671 ) 672 max_pool_node = mod.graph.findNode("aten::max_pool2d_with_indices") 673 outs = list(max_pool_node.outputs()) 674 self.assertEqual( 675 outs[0].type().symbolic_sizes(), outs[1].type().symbolic_sizes() 676 ) 677 g = shape_compute_graph.partial_eval_shape_graph() 678 # to make into a jit function cant have multiple outputs 679 g.makeMultiOutputIntoTuple() 680 func = torch._C._create_function_from_graph("partial_eval_graph", g) 681 mapping = shape_compute_graph.graph_output_to_symbolic_shape_dim() 682 output_shape = func(tensor.size()) 683 # the first 4 dims are input sym dimensions, then the , 684 self.assertEqual(list(output_shape[0:4]), list(tensor.size())) 685 self.assertEqual(list(output_shape[4:]), output_tensor[2:]) 686 687 def test_sym_ir_parsing(self): 688 graph_str1 = """graph(%x.1 : Float(SS(-2), SS(-3))): 689 %3 : int = prim::Constant[value=1]() 690 %4 : Tensor = aten::add(%x.1, %x.1, %3) 691 return (%4)""" 692 g = torch._C.parse_ir(graph_str1) 693 inp = next(g.inputs()) 694 out = inp.type().symbolic_sizes() 695 self.assertEqual(out, [-2, -3]) 696 697 def test_stitching_concat(self): 698 @torch.jit.script 699 def foo1(a, b, x, y): 700 return (a / b) + torch.cat([x, y]) 701 702 @torch.jit.script 703 def foo2(a, b, x, y): 704 return (a / b) + torch.cat([x, y], dim=-2) 705 706 for foo in [foo1, foo2]: 707 g = foo.graph 708 for inp in foo.graph.inputs(): 709 inp.setType(inp.type().with_sizes([None, None])) 710 711 shape_compute_graph = ( 712 torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute( 713 foo.graph 714 ) 715 ) 716 nodes = ( 717 [g.findNode("aten::div")] 718 + [g.findNode("aten::add")] 719 + [g.findNode("aten::cat")] 720 ) 721 722 inps = [1, 10], [20, 10], [15, 1], [5, 1] 723 output_shapes = [[20, 10], [20, 10], [20, 1]] 724 725 self.checkSymShapeCompute(shape_compute_graph, nodes, output_shapes, inps) 726 727 @unittest.skipIf( 728 not hasattr(torch.jit, "_shapes"), "shape functions not loaded in python" 729 ) 730 def test_shape_function_includes(self): 731 inp_shape = [1, 16, 5, 10] 732 weight_shape = [33, 16, 3, 3] 733 bias = None 734 stride = [2, 2] 735 padding = [0, 0] 736 dilation = [1, 1] 737 groups = 1 738 res = torch.jit._shapes.conv2d( 739 inp_shape, weight_shape, bias, stride, padding, dilation, groups 740 ) 741 self.assertEqual(res, [1, 33, 2, 4]) 742 743 m1_shape = [10, 20] 744 m2_shape = [20, 10] 745 res = torch.jit._shapes.matmul(m1_shape, m2_shape) 746 self.assertEqual(res, [10, 10]) 747 748 def test_register_function_error_checking(self): 749 # this will error before registering on global map, so 750 # no issue in overwriting schema mappings 751 @torch.jit.script 752 def foo(x, y): 753 return x + y 754 755 node = foo.graph.findNode("aten::add") 756 757 @torch.jit.script 758 def wrong_input_types(x, y): 759 x: List[int] = [] 760 return x 761 762 with self.assertRaisesRegex(RuntimeError, "Expected supertype of int"): 763 torch._C._jit_register_shape_compute_graph_for_node( 764 node, wrong_input_types.graph 765 ) 766 767 @torch.jit.script 768 def wrong_output_types(x: List[int], y: List[int]): 769 x: List[Tensor] = [] 770 return x 771 772 with self.assertRaisesRegex(RuntimeError, "but got graph_type"): 773 torch._C._jit_register_shape_compute_graph_for_node( 774 node, wrong_output_types.graph 775 ) 776 777 @torch.jit.script 778 def too_many_inputs(x: List[int], y: List[int], z: Any, z2: Any): 779 x: List[int] = [] 780 return x 781 782 with self.assertRaises(RuntimeError) as error: 783 torch._C._jit_register_shape_compute_graph_for_node( 784 node, too_many_inputs.graph 785 ) 786 787 self.assertTrue("fewer arguments than schema" in str(error.exception)) 788 789 def test_cross_entropy_loss(self): 790 @torch.jit.script 791 def foo(x, y): 792 return torch.ops.aten.cross_entropy_loss(x, y, reduction=0) 793 794 inputs = list(foo.graph.inputs()) 795 inputs[0].setType(inputs[0].type().with_sizes([8, 2])) 796 inputs[1].setType( 797 inputs[1] 798 .type() 799 .with_sizes( 800 [ 801 8, 802 ] 803 ) 804 ) 805 torch._C._jit_pass_propagate_shapes_on_graph(foo.graph) 806 self.assertEqual( 807 next(foo.graph.outputs()).type().sizes(), 808 [ 809 8, 810 ], 811 ) 812 813 def test_squeeze_dims(self): 814 @torch.jit.script 815 def foo(x): 816 return torch.ops.aten.squeeze(x, dim=0) 817 818 input = next(foo.graph.inputs()) 819 input.setType(input.type().with_sizes([1, 5, 8])) 820 torch._C._jit_pass_propagate_shapes_on_graph(foo.graph) 821 self.assertEqual(next(foo.graph.outputs()).type().symbolic_sizes(), [5, 8]) 822