xref: /aosp_15_r20/external/pytorch/test/jit/test_symbolic_shape_analysis.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import operator
4import unittest
5from textwrap import dedent
6from typing import Any, List
7
8import torch
9from torch import nn, Tensor
10from torch.testing import FileCheck
11from torch.testing._internal.common_methods_invocations import sample_inputs_cat_concat
12from torch.testing._internal.common_utils import make_tensor
13from torch.testing._internal.jit_utils import execWrapper, JitTestCase
14
15
if __name__ == "__main__":
    # This module is not a standalone entry point: refuse direct execution and
    # point the user at the test/test_jit.py runner instead.
    direct_run_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(direct_run_message)
22
23
24# XXX: still in prototype
25class TestSymbolicShapeAnalysis(JitTestCase):
26    def setUp(self):
27        super(JitTestCase, self).setUp()
28        self.prev_symbolic_shapes_test_enabled = (
29            torch._C._jit_symbolic_shapes_test_mode_enabled()
30        )
31        torch._C._jit_set_symbolic_shapes_test_mode(True)
32
33    def tearDown(self):
34        torch._C._jit_set_symbolic_shapes_test_mode(
35            self.prev_symbolic_shapes_test_enabled
36        )
37
38    def test_shape_analysis(self):
39        @torch.jit.script
40        def foo(x, y):
41            return x * y
42
43        inputs = list(foo.graph.inputs())
44
45        def prop_shapes_on_graph(inp0, inp1):
46            inputs[0].setType(inputs[0].type().with_sizes(inp0))
47            inputs[1].setType(inputs[1].type().with_sizes(inp1))
48            torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)
49
50        prop_shapes_on_graph([1, 6, 5], [1, 7, 1, 5])
51        FileCheck().check("1, 7, 6, 5").run(foo.graph)
52
53        # None implicitly creates a new symbolic symbol
54        prop_shapes_on_graph([None, None], [None, None, None])
55        output_shape = foo.graph.findNode("aten::mul").output().type().symbolic_sizes()
56        inp0_shape = inputs[0].type().symbolic_sizes()
57        inp1_shape = inputs[1].type().symbolic_sizes()
58
59        # output shape dim 0 should be taken from the second inp dim0
60        # other two dims we cannot infer and are given a new symbolic shape
61        self.assertEqual(output_shape[0], inp1_shape[0])
62        self.assertFalse(output_shape[1] in inp0_shape + inp1_shape)
63        self.assertFalse(output_shape[2] in inp0_shape + inp1_shape)
64
65        # XXX: symbolic shapes are represented with an increasing counter of unique
66        # values, use `_new_symbolic_shape_symbol` api instead of specifying negative
67        # dimensions directly so there is no chance of collision between manual number
68        # and current counter value.
69        sym1 = torch._C._new_symbolic_shape_symbol()
70        sym2 = torch._C._new_symbolic_shape_symbol()
71        sym3 = torch._C._new_symbolic_shape_symbol()
72        prop_shapes_on_graph([sym1, 1, sym3], [1, sym2, sym3])
73        output_shape = foo.graph.findNode("aten::mul").output().type().symbolic_sizes()
74        self.assertEqual(output_shape[0], sym1)
75        self.assertEqual(output_shape[1], sym2)
76        self.assertEqual(output_shape[2], sym3)
77
78    def test_shared_shape_graph(self):
79        @torch.jit.script
80        def foo(x, y):
81            return x * y, x / y
82
83        mul_node = foo.graph.findNode("aten::mul")
84        div_node = foo.graph.findNode("aten::div")
85
86        mul_graph = torch._C._jit_shape_compute_graph_for_node(mul_node)
87        div_graph = torch._C._jit_shape_compute_graph_for_node(div_node)
88        self.assertIsNotNone(mul_graph)
89        self.assertIs(mul_graph, div_graph)
90
91    def test_write(self):
92        @torch.jit.script
93        def foo(a, b):
94            return a * b
95
96        # broadcast appends cant be removed, so we bail on propagation
97        torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)
98        FileCheck().check("Tensor = aten::mul").run(foo.graph)
99
100        @torch.jit.script
101        def foo(y):
102            x = [1, 2, 3, 4]
103            x[0] = 5
104            return y.view(x)
105
106        torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)
107        FileCheck().check("Tensor = aten::view").run(foo.graph)
108
109    def test_if_propagation(self):
110        @torch.jit.script
111        def foo(i: int, z):
112            x = torch.ones([2, 3, 4, 5])
113            y = z.view([z.size(i), 3, 2, z.size(i)])
114            if i == 4:
115                return x
116            else:
117                return y
118
119        torch._C._jit_pass_constant_propagation(foo.graph)
120        torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)
121        view = foo.graph.findNode("aten::view")
122
123        def neg_to_one(li):
124            return [elem if elem >= 0 else -1 for elem in li]
125
126        self.assertEqual(
127            neg_to_one(view.output().type().symbolic_sizes()), [-1, 3, 2, -1]
128        )
129        if_out = next(foo.graph.findNode("prim::If").outputs())
130        self.assertEqual(neg_to_one(if_out.type().symbolic_sizes()), [-1, 3, -1, -1])
131
132    def test_unary_shape_functions(self):
133        unary_ops = [
134            torch.nn.functional.hardtanh,
135        ]
136        for fn in unary_ops:
137            t = torch.jit.trace(fn, (torch.rand([4, 4])))
138            ten_input = next(t.graph.inputs())
139            ten_input.setType(ten_input.type().with_sizes([2, 2]))
140            torch._C._jit_pass_propagate_shapes_on_graph(t.graph)
141            self.assertEqual(next(t.graph.outputs()).type().symbolic_sizes(), [2, 2])
142
143    def test_unary_shape_fns_inplace(self):
144        def mul_inplace(x: torch.Tensor):
145            y = x.mul_(2)
146            return y
147
148        unary_ops = [mul_inplace]
149        for fn in unary_ops:
150            # t = torch.jit.trace(fn, torch.rand([4, 4]))  # For some reason tracing is erroring out.
151            t = torch.jit.script(fn)
152            ten_input = next(t.graph.inputs())
153            ten_input.setType(ten_input.type().with_sizes([2, 2]))
154            torch._C._jit_pass_propagate_shapes_on_graph(t.graph)
155            self.assertEqual(next(t.graph.outputs()).type().symbolic_sizes(), [2, 2])
156
157    def test_binary_shape_functions(self):
158        binary_ops = [
159            operator.__mul__,
160            operator.__truediv__,
161            operator.__gt__,
162            operator.__add__,
163        ]
164
165        for fn in binary_ops:
166            size_1 = [1, 4, 8]
167            size_2 = [4, 1, 8]
168            t = torch.jit.trace(fn, (torch.rand([4]), torch.rand([4])))
169            inputs = list(t.graph.inputs())
170            inputs[0].setType(inputs[0].type().with_sizes(size_1))
171            inputs[1].setType(inputs[1].type().with_sizes(size_2))
172            torch._C._jit_pass_propagate_shapes_on_graph(t.graph)
173            self.assertEqual(next(t.graph.outputs()).type().symbolic_sizes(), [4, 4, 8])
174
175    def test_binary_shape_fns_inplace(self):
176        def div_inplace_tensor(x: torch.Tensor, y: torch.Tensor):
177            z = x.div_(y)
178            return z
179
180        def add_inplace_tensor(x: torch.Tensor, y: torch.Tensor):
181            z = x.add_(y)
182            return z
183
184        binary_ops = [
185            div_inplace_tensor,
186            add_inplace_tensor,
187        ]
188
189        for fn in binary_ops:
190            size_1 = [4, 4, 8]  # x (can't broadcast because it's an inplace op)
191            t = torch.jit.script(fn)
192            inputs = list(t.graph.inputs())
193            inputs[0].setType(inputs[0].type().with_sizes(size_1))
194            # Intentionally not populate the type of inputs[1]
195            torch._C._jit_pass_propagate_shapes_on_graph(t.graph)
196            self.assertEqual(next(t.graph.outputs()).type().symbolic_sizes(), [4, 4, 8])
197
198    def test_size_and_sizes(self):
199        @torch.jit.script
200        def foo(x, y):
201            return x.view(y.size(0), 8, y.size(-1))
202
203        @torch.jit.script
204        def foo2(x, y):
205            return x.view(y.size())
206
207        for graph in [foo.graph, foo2.graph]:
208            inputs = list(graph.inputs())
209            sym1 = torch._C._new_symbolic_shape_symbol()
210
211            inputs[1].setType(inputs[1].type().with_sizes([5, 8, sym1]))
212            torch._C._jit_pass_propagate_shapes_on_graph(graph)
213            self.assertEqual(
214                next(graph.outputs()).type().symbolic_sizes(), [5, 8, sym1]
215            )
216
217    def test_adaptive_avg_pool2d(self):
218        inps = [
219            [(1, 64, 8, 9), (5, 7)],
220            [(1, 64, 10, 9), (7)],
221            [(1, 64, 10, 9), (5, None)],
222            [(1, 8, 4, 3), (None, None)],
223            [(1, 8, 4, 3), (None, 5)],
224        ]
225
226        for inp in inps:
227            t = torch.randn(*inp[0])
228            out_size = torch.nn.functional.adaptive_avg_pool2d(t, inp[1]).size()
229
230            def foo(x):
231                return torch.nn.functional.adaptive_avg_pool2d(x, inp[1])
232
233            fn = torch.jit.trace(foo, (t,))
234            torch._C._jit_erase_non_input_shape_information(fn.graph)
235            torch._C._jit_pass_peephole(fn.graph)
236            torch._C._jit_pass_constant_propagation(fn.graph)
237            self.checkShapeAnalysis(out_size, fn.graph, assert_propagation=True)
238
239    def test_conv_deconv(self):
240        for (
241            inp_shape,
242            weight_shape,
243            bias,
244            stride,
245            padding,
246            output_padding,
247            dilation,
248            groups,
249            mod,
250        ) in [
251            ([32, 6, 10], [16, 3, 3], None, 2, 2, 1, 1, 2, torch.nn.functional.conv1d),
252            (
253                [32, 16, 10],
254                [16, 3, 3],
255                None,
256                2,
257                2,
258                1,
259                1,
260                2,
261                torch.nn.functional.conv_transpose1d,
262            ),
263            (
264                [1, 32, 5, 10],
265                [30, 16, 3, 3],
266                None,
267                [2, 2],
268                [0, 0],
269                0,
270                1,
271                2,
272                torch.nn.functional.conv2d,
273            ),
274            (
275                [1, 30, 5, 10],
276                [30, 16, 3, 3],
277                None,
278                [2, 2],
279                [0, 0],
280                0,
281                1,
282                2,
283                torch.nn.functional.conv_transpose2d,
284            ),
285            (
286                [3, 14, 10, 66, 55],
287                [2, 7, 7, 4, 4],
288                None,
289                1,
290                1,
291                2,
292                1,
293                2,
294                torch.nn.functional.conv3d,
295            ),
296            (
297                [3, 2, 10, 66, 55],
298                [2, 7, 7, 4, 4],
299                None,
300                1,
301                1,
302                0,
303                1,
304                2,
305                torch.nn.functional.conv_transpose3d,
306            ),
307        ]:
308            inp = torch.rand(inp_shape)
309            weight = torch.rand(weight_shape)
310            if mod in [
311                torch.nn.functional.conv1d,
312                torch.nn.functional.conv2d,
313                torch.nn.functional.conv3d,
314            ]:
315                res = mod(inp, weight, bias, stride, padding, dilation, groups).size()
316            else:
317                res = mod(
318                    inp, weight, bias, stride, padding, output_padding, dilation, groups
319                ).size()
320
321            def foo(inp, weight):
322                if mod in [
323                    torch.nn.functional.conv1d,
324                    torch.nn.functional.conv2d,
325                    torch.nn.functional.conv3d,
326                ]:
327                    return mod(inp, weight, bias, stride, padding, dilation, groups)
328                else:
329                    return mod(
330                        inp,
331                        weight,
332                        bias,
333                        stride,
334                        padding,
335                        output_padding,
336                        dilation,
337                        groups,
338                    )
339
340            fn = torch.jit.trace(foo, (inp, weight))
341            torch._C._jit_erase_non_input_shape_information(fn.graph)
342            torch._C._jit_pass_peephole(fn.graph)
343            torch._C._jit_pass_constant_propagation(fn.graph)
344            self.checkShapeAnalysis(res, fn.graph, assert_propagation=True)
345
346    def test_arange_shape(self):
347        # no opinfo for tensor constructors
348        inps = [
349            (10,),
350            (10, 10),
351            (0, 10),
352            (0, 1000),
353            (1, -1, -1),
354            (1, 0, -1),
355            (1, 2, 1),
356            (0.6, 0.89, 0.1),
357            (1, 10, 0.3),
358            (1, 10, 4),
359            (0.6, 0.7, 0.8),
360            (1, 10, 0.3),
361            # (True,),  TODO: https://github.com/pytorch/pytorch/issues/63405
362            # (False,), TODO: https://github.com/pytorch/pytorch/issues/63405
363            (0, 5),
364            (0, 5, 2),
365            (0, 5 + 1e-6),
366            (0, 5 - 1e-6),
367            (10, -1 + 1e-6, -1),
368            (10, -1, -1),
369            (10, -1 - 1e-6, -1),
370        ]
371
372        for inp in inps:
373            funcs_template = dedent(
374                """
375            def func():
376                return torch.arange({args})
377            """
378            )
379
380            inp_s = str(inp)[1:-1]  # remove tuple parens
381            funcs_str = funcs_template.format(args=inp_s)
382            scope = {}
383            execWrapper(funcs_str, globals(), scope)
384            cu = torch.jit.CompilationUnit(funcs_str)
385            self.checkShapeAnalysis(
386                list(cu.func().size()),
387                cu.func.graph,
388                assert_propagation=True,
389                constant_prop=False,
390            )
391
392    def test_shape_embedding_bag(self):
393        # TODO: merge into opinfos, having difficulties there
394        with torch.no_grad():
395
396            def make_arg(shape, low=None, high=None):
397                return make_tensor(
398                    shape,
399                    device="cpu",
400                    dtype=torch.int64,
401                    low=low,
402                    high=high,
403                    requires_grad=False,
404                )
405
406            nn_inps = (
407                (
408                    make_arg((40,), 0, 9),
409                    torch.nn.Embedding(20, embedding_dim=64, max_norm=1.0),
410                ),
411                (make_arg((2, 4), 0, 9), torch.nn.Embedding(10, 20, sparse=True)),
412                (make_arg((0,)), torch.nn.Embedding(0, 0, sparse=True)),
413                (make_arg((2, 4), 0, 9), torch.nn.Embedding(10, 0, sparse=True)),
414                (make_arg((4,), 0, 21), torch.nn.Embedding(22, 5, max_norm=1.0)),
415                (
416                    make_arg((2,), 0, 1),
417                    torch.nn.Embedding.from_pretrained(
418                        torch.arange(6.0).view(2, 3),
419                        max_norm=2.0,
420                        norm_type=0.5,
421                        scale_grad_by_freq=False,
422                        sparse=True,
423                    ),
424                ),
425            )
426
427            for inp, module in nn_inps:
428                kwargs = {
429                    "weight": module.weight.detach(),
430                    "padding_idx": module.padding_idx,
431                    "max_norm": module.max_norm,
432                    "norm_type": module.norm_type,
433                    "scale_grad_by_freq": module.scale_grad_by_freq,
434                    "sparse": module.sparse,
435                }
436
437                out_size = torch.nn.functional.embedding(inp, **kwargs).size()
438
439                def foo(x):
440                    return torch.nn.functional.embedding(inp, **kwargs)
441
442                fn = torch.jit.trace(foo, (inp.detach(),), check_trace=False)
443
444                self.checkShapeAnalysis(
445                    out_size, fn.graph, assert_propagation=True, constant_prop=False
446                )
447
448    def test_shape_concat(self):
449        # TODO: unify with opinfo tests, traces of lists dont preserve sizes in IR
450        sample_inputs = sample_inputs_cat_concat(None, "cpu", torch.float, False)
451
452        class CatMod(nn.Module):
453            __constants__ = ["dim"]
454
455            def __init__(self, dim=0):
456                super().__init__()
457                self.dim = dim
458
459            def forward(self, x, y):
460                return torch.cat([x, y], dim=self.dim)
461
462        for inp in sample_inputs:
463            mod = torch.jit.script(CatMod(**inp.kwargs).eval())
464
465            args = inp.input
466
467            # This test is hard-coded only to work with two sample inputs
468            # but the OpInfo may have more/less
469            if len(args) != 2:
470                continue
471
472            out_size = mod(*args).size()
473            inps = list(mod.graph.inputs())
474            inps[1].setType(inps[1].type().with_sizes(args[0].size()))
475            inps[2].setType(inps[2].type().with_sizes(args[1].size()))
476            self.checkShapeAnalysis(out_size, mod.graph, assert_propagation=True)
477
478    def assert_shape_equal_scripted(self, script_fn, given_ins):
479        expected_res = script_fn(*given_ins)
480        g = script_fn.graph
481        graph_ins = list(g.inputs())
482        self.assertEqual(len(given_ins), len(graph_ins))
483        for inp, graph_in in zip(given_ins, graph_ins):
484            graph_in.setType(graph_in.type().with_sizes(inp.size()))
485
486        out_sizes = [out.size() for out in expected_res]
487        self.checkShapeAnalysis(out_sizes, g, assert_propagation=True)
488
489    def test_convolution_backward(self):
490        # No opinfos for ops that are not part of the Python API
491        # Also, as the return shapes are the input, weight, and bias shape, there is no point
492        # in a really complicated test
493
494        input = torch.randn(
495            (16, 16, 8, 8), dtype=torch.float32, device="cpu", requires_grad=True
496        )
497        weight = torch.randn(
498            (8, 4, 3, 3), dtype=torch.float32, device="cpu", requires_grad=True
499        )
500        out_grad = torch.randn((16, 8, 8, 8), dtype=torch.float32, device="cpu")
501
502        @torch.jit.script
503        def conv_bwd(input, weight, grad):
504            bias_sizes = [
505                8,
506            ]
507            args = ([1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, True])
508            return torch.ops.aten.convolution_backward(
509                grad, input, weight, bias_sizes, *args
510            )
511
512        self.assert_shape_equal_scripted(conv_bwd, (input, weight, out_grad))
513
514        @torch.jit.script
515        def conv_bwd_2(input, weight, grad):
516            bias_sizes = None
517            args = ([1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, True])
518            return torch.ops.aten.convolution_backward(
519                grad, input, weight, bias_sizes, *args
520            )
521
522        self.assert_shape_equal_scripted(conv_bwd_2, (input, weight, out_grad))
523
524    def test_returning_input_symbolic_shapes(self):
525        mm = torch.jit.freeze(torch.jit.script(nn.Conv2d(16, 33, 3, stride=2).eval()))
526        inps = list(mm.graph.inputs())
527        inps[1].setType(inps[1].type().with_sizes([None, None, None, None]))
528        shape_compute_graph = (
529            torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mm.graph)
530        )
531        g = shape_compute_graph.partial_eval_shape_graph()
532        # to make into a jit function cant have multiple outputs
533        g.makeMultiOutputIntoTuple()
534        func = torch._C._create_function_from_graph("partial_eval_graph", g)
535        out = func([20, 16, 5, 10])
536        # first four outputs should be unknown symbolic shapes from input
537        self.assertEqual(out[0:4], [20, 16, 5, 10])
538        # last two are two new symbolic dims - height and width
539        self.assertEqual(out[4:], list(mm(torch.rand([20, 16, 5, 10])).size()[2:]))
540
541    def test_partial_eval_graph_conv(self):
542        mm = torch.jit.freeze(torch.jit.script(nn.Conv2d(16, 33, 3, stride=2).eval()))
543        shape_compute_graph = (
544            torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mm.graph)
545        )
546        output_sizes = (
547            mm.graph.findNode("aten::conv2d").output().type().symbolic_sizes()
548        )
549        # calculating 0, 2 and 3 index
550        for i in [0, 2, 3]:
551            self.assertTrue(output_sizes[i] < 0)
552        self.assertTrue(output_sizes[1] >= 0)
553        g = shape_compute_graph.partial_eval_shape_graph()
554        # to make into a jit function cant have multiple outputs
555        g.makeMultiOutputIntoTuple()
556        func = torch._C._create_function_from_graph("partial_eval_graph", g)
557        inp = torch.randn(20, 16, 5, 10)
558        output = func([20, 16, 5, 10])
559        output_eager = list(mm(inp).size())
560        for o, oe in zip(output, output_eager[0:1] + output_eager[2:]):
561            self.assertEqual(o, oe)
562
563    def checkSymShapeCompute(
564        self, shape_compute_graph, nodes, node_output_sizes, shape_inputs
565    ):
566        g = shape_compute_graph.partial_eval_shape_graph()
567        self.assertTrue(len(list(g.inputs())) == len(shape_inputs))
568        output_sym_map = shape_compute_graph.graph_output_to_symbolic_shape_dim()
569        # map from sym shape -> index
570        sym_shape_to_index = {}
571        for index, output in enumerate(g.outputs()):
572            sym_shape_to_index[output_sym_map[output]] = index
573
574        g.makeMultiOutputIntoTuple()
575        func = torch._C._create_function_from_graph("partial_eval_graph", g)
576        sym_outputs = func(*shape_inputs)
577
578        for node, output_shape in zip(nodes, node_output_sizes):
579            output_type_sizes = node.output().type().symbolic_sizes()
580            for i, sym_shape in enumerate(output_type_sizes):
581                if sym_shape >= 0:
582                    self.assertEqual(sym_shape, output_shape[i])
583                else:
584                    sym_shape_index = sym_shape_to_index[sym_shape]
585                    self.assertEqual(sym_outputs[sym_shape_index], output_shape[i])
586
587    def test_partial_eval_stitching(self):
588        conv1 = torch.nn.Conv2d(
589            3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
590        )
591        max_pool = torch.nn.MaxPool2d(
592            kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False
593        )
594        conv2 = nn.Conv2d(
595            64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
596        )
597
598        mod = torch.jit.freeze(
599            torch.jit.script(nn.Sequential(conv1, max_pool, conv2).eval())
600        )
601
602        conv1_output = conv1(torch.rand(1, 3, 224, 224))
603        max_pool_output = max_pool(conv1_output)
604        conv2_output = conv2(max_pool_output)
605
606        shape_compute_graph = (
607            torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph)
608        )
609        nodes = [mod.graph.findNode("aten::max_pool2d")] + list(
610            mod.graph.findAllNodes("aten::conv2d")
611        )
612        output_shapes = [
613            max_pool_output.size(),
614            conv1_output.size(),
615            conv2_output.size(),
616        ]
617        self.checkSymShapeCompute(
618            shape_compute_graph, nodes, output_shapes, ([1, 3, 224, 224],)
619        )
620
621    def test_refinement_through_graph_stitching(self):
622        class TwoConvs(torch.nn.Module):
623            def __init__(self) -> None:
624                super().__init__()
625                self.conv1 = torch.nn.Conv2d(
626                    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
627                )
628                self.conv2 = torch.nn.Conv2d(
629                    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
630                )
631
632            def forward(self, x):
633                a = self.conv1(x)
634                b = self.conv2(x)
635                return a + b
636
637        mod = torch.jit.freeze(torch.jit.script(TwoConvs()).eval())
638        inp_tensor = list(mod.graph.inputs())[1]
639        inp_tensor.setType(inp_tensor.type().with_sizes([None, None, None, None]))
640        torch._C._jit_pass_propagate_shapes_on_graph(mod.graph)
641        outs = list(next(mod.graph.outputs()).node().inputs())
642        out1 = outs[0].type().symbolic_sizes()
643        out2 = outs[1].type().symbolic_sizes()
644        self.assertTrue(out1[2] != out2[2])
645        self.assertTrue(out1[3] != out2[3])
646        # by joining partial eval graphs of both convs we are able to recognize the output shapes
647        # are equivalent
648        torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph)
649        out1 = outs[0].type().symbolic_sizes()
650        out2 = outs[1].type().symbolic_sizes()
651        self.assertEqual(out1, out2)
652
653    def test_stitching_multi_output(self):
654        max_pool = torch.nn.MaxPool2d(
655            kernel_size=3,
656            stride=2,
657            padding=1,
658            dilation=1,
659            ceil_mode=False,
660            return_indices=True,
661        )
662        tensor = torch.rand(1, 3, 224, 224)
663        mod = torch.jit.trace(max_pool, (tensor,))
664        mod = torch.jit.freeze(mod.eval())
665        inp = list(mod.graph.inputs())[1]
666        inp.setType(inp.type().with_sizes([None, None, None, None]))
667        output_tensor = list(mod(tensor)[0].size())
668        self.run_pass("lower_all_tuples", mod.graph)
669        shape_compute_graph = (
670            torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph)
671        )
672        max_pool_node = mod.graph.findNode("aten::max_pool2d_with_indices")
673        outs = list(max_pool_node.outputs())
674        self.assertEqual(
675            outs[0].type().symbolic_sizes(), outs[1].type().symbolic_sizes()
676        )
677        g = shape_compute_graph.partial_eval_shape_graph()
678        # to make into a jit function cant have multiple outputs
679        g.makeMultiOutputIntoTuple()
680        func = torch._C._create_function_from_graph("partial_eval_graph", g)
681        mapping = shape_compute_graph.graph_output_to_symbolic_shape_dim()
682        output_shape = func(tensor.size())
683        # the first 4 dims are input sym dimensions, then the ,
684        self.assertEqual(list(output_shape[0:4]), list(tensor.size()))
685        self.assertEqual(list(output_shape[4:]), output_tensor[2:])
686
687    def test_sym_ir_parsing(self):
688        graph_str1 = """graph(%x.1 : Float(SS(-2), SS(-3))):
689                        %3 : int = prim::Constant[value=1]()
690                        %4 : Tensor = aten::add(%x.1, %x.1, %3)
691                        return (%4)"""
692        g = torch._C.parse_ir(graph_str1)
693        inp = next(g.inputs())
694        out = inp.type().symbolic_sizes()
695        self.assertEqual(out, [-2, -3])
696
697    def test_stitching_concat(self):
698        @torch.jit.script
699        def foo1(a, b, x, y):
700            return (a / b) + torch.cat([x, y])
701
702        @torch.jit.script
703        def foo2(a, b, x, y):
704            return (a / b) + torch.cat([x, y], dim=-2)
705
706        for foo in [foo1, foo2]:
707            g = foo.graph
708            for inp in foo.graph.inputs():
709                inp.setType(inp.type().with_sizes([None, None]))
710
711            shape_compute_graph = (
712                torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(
713                    foo.graph
714                )
715            )
716            nodes = (
717                [g.findNode("aten::div")]
718                + [g.findNode("aten::add")]
719                + [g.findNode("aten::cat")]
720            )
721
722            inps = [1, 10], [20, 10], [15, 1], [5, 1]
723            output_shapes = [[20, 10], [20, 10], [20, 1]]
724
725            self.checkSymShapeCompute(shape_compute_graph, nodes, output_shapes, inps)
726
727    @unittest.skipIf(
728        not hasattr(torch.jit, "_shapes"), "shape functions not loaded in python"
729    )
730    def test_shape_function_includes(self):
731        inp_shape = [1, 16, 5, 10]
732        weight_shape = [33, 16, 3, 3]
733        bias = None
734        stride = [2, 2]
735        padding = [0, 0]
736        dilation = [1, 1]
737        groups = 1
738        res = torch.jit._shapes.conv2d(
739            inp_shape, weight_shape, bias, stride, padding, dilation, groups
740        )
741        self.assertEqual(res, [1, 33, 2, 4])
742
743        m1_shape = [10, 20]
744        m2_shape = [20, 10]
745        res = torch.jit._shapes.matmul(m1_shape, m2_shape)
746        self.assertEqual(res, [10, 10])
747
748    def test_register_function_error_checking(self):
749        # this will error before registering on global map, so
750        # no issue in overwriting schema mappings
751        @torch.jit.script
752        def foo(x, y):
753            return x + y
754
755        node = foo.graph.findNode("aten::add")
756
757        @torch.jit.script
758        def wrong_input_types(x, y):
759            x: List[int] = []
760            return x
761
762        with self.assertRaisesRegex(RuntimeError, "Expected supertype of int"):
763            torch._C._jit_register_shape_compute_graph_for_node(
764                node, wrong_input_types.graph
765            )
766
767        @torch.jit.script
768        def wrong_output_types(x: List[int], y: List[int]):
769            x: List[Tensor] = []
770            return x
771
772        with self.assertRaisesRegex(RuntimeError, "but got graph_type"):
773            torch._C._jit_register_shape_compute_graph_for_node(
774                node, wrong_output_types.graph
775            )
776
777        @torch.jit.script
778        def too_many_inputs(x: List[int], y: List[int], z: Any, z2: Any):
779            x: List[int] = []
780            return x
781
782        with self.assertRaises(RuntimeError) as error:
783            torch._C._jit_register_shape_compute_graph_for_node(
784                node, too_many_inputs.graph
785            )
786
787        self.assertTrue("fewer arguments than schema" in str(error.exception))
788
789    def test_cross_entropy_loss(self):
790        @torch.jit.script
791        def foo(x, y):
792            return torch.ops.aten.cross_entropy_loss(x, y, reduction=0)
793
794        inputs = list(foo.graph.inputs())
795        inputs[0].setType(inputs[0].type().with_sizes([8, 2]))
796        inputs[1].setType(
797            inputs[1]
798            .type()
799            .with_sizes(
800                [
801                    8,
802                ]
803            )
804        )
805        torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)
806        self.assertEqual(
807            next(foo.graph.outputs()).type().sizes(),
808            [
809                8,
810            ],
811        )
812
813    def test_squeeze_dims(self):
814        @torch.jit.script
815        def foo(x):
816            return torch.ops.aten.squeeze(x, dim=0)
817
818        input = next(foo.graph.inputs())
819        input.setType(input.type().with_sizes([1, 5, 8]))
820        torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)
821        self.assertEqual(next(foo.graph.outputs()).type().symbolic_sizes(), [5, 8])
822