# Owner(s): ["oncall: jit"]

import copy
import io
import os
import sys
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function, Variable
from torch.testing import FileCheck


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
import warnings

# Standard library
from collections import namedtuple
from itertools import chain
from typing import Dict, List, Optional, Tuple

from torch import Tensor
from torch.testing._internal.common_cuda import with_tf32_off
from torch.testing._internal.common_utils import (
    enable_profiling_mode_for_profiling_tests,
    IS_SANDCASTLE,
    skipIfCompiledWithoutNumpy,
    skipIfCrossRef,
    skipIfTorchDynamo,
    suppress_warnings,
    TemporaryFileName,
)
from torch.testing._internal.jit_utils import (
    _tmp_donotuse_dont_inline_everything,
    _trace,
    enable_cpu_fuser,
    JitTestCase,
    make_global,
    RUN_CUDA,
    RUN_CUDA_MULTI_GPU,
)


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


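# The tests below exercise torch.jit.trace, which records the operators executed
# for the given example inputs and builds a TorchScript graph from that single
# run; Python control flow is not captured, only the path actually taken.
# A minimal illustration (not part of the test suite):
#
#     def double(x):
#         return x * 2
#
#     traced = torch.jit.trace(double, torch.randn(2, 2))
#     traced(torch.ones(2, 2))  # runs the recorded graph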
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
class TestTracer(JitTestCase):
    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    def test_large_nbr_kernel_args(self):
        class Recurrence(nn.Module):
            def __init__(self, seq_len):
                super().__init__()
                self.seq_len = seq_len

            def forward(self, input):
                input = input.transpose(0, 1)

                # Main loop
                output = []
                for i in range(self.seq_len):
                    b = input[i] * 2
                    output.append(b)

                output = torch.cat(output, 0).view(input.size(0), *output[0].size())
                output = output.transpose(0, 1)
                return output

        input_size = 8
        batch_size = 2
        seq_len = 130

        rec = Recurrence(seq_len)
        input = torch.rand(batch_size, seq_len, input_size)

        torch.cuda.set_device(0)
        rec = rec.cuda()
        input = input.cuda()

        traced_rec = torch.jit.trace(rec, (input,))

    def test_trace_legacy_ctor(self):
        class MyModule(nn.Module):
            def forward(self, x):
                return (x + 1, torch.FloatTensor([0]))

        traced_rec = torch.jit.trace(MyModule(), torch.randn(2, 2))

    def test_simple(self):
        x = torch.tensor([0.4], requires_grad=True)
        y = torch.tensor([0.7], requires_grad=True)

        def f(x, y):
            return torch.sigmoid(torch.tanh(x * (x + y)))

        self.checkTrace(f, (x, y))

    def test_trace_checking_with_global_name(self):
        class MyClass(torch.nn.Module):
            def forward(self, xs: List[Tensor]):
                y = torch.cat(xs, dim=0)
                return y

        model = MyClass()
        # Simulate these inputs being in the globals, like they would be if,
        # e.g., they were defined in the outermost scope of a script
        global input1, input2
        input1 = torch.ones(2, 2)
        input2 = torch.ones(2, 2)
        m2 = torch.jit.trace(model, ((input1, input2),))

    def test_trace_aliased_parameter(self):
        class M(nn.Module):
            def __init__(self, x):
                super().__init__()
                self.x = nn.Parameter(x)

            def forward(self, y):
                return self.x + y

        m = M(torch.rand(3, 4))
        r = torch.jit.trace(m, m.x)
        t2 = torch.rand(3, 4)
        self.assertEqual(r(t2), m.x + t2)

    def test_trace_nested_fn(self):
        class TracedInlineDecision(torch.nn.Module):
            def forward(self, x, flag):
                @torch.jit.script
                def make_decision(flag, x):
                    if flag:
                        return x
                    else:
                        return torch.zeros_like(x)

                x = torch.neg(x)
                return make_decision(flag, x)

        decision = TracedInlineDecision()
        torch.jit.trace(
            decision,
            (torch.rand(3, 4), torch.tensor([True], dtype=torch.bool)),
            check_trace=True,
        )

    def test_trace_single_tuple(self):
        x = torch.tensor(2.0)

        def f2(x):
            return (x,)

        jit_f2 = torch.jit.trace(f2, x)
        assert f2(x) == jit_f2(x)  # fails

    def test_trace_out_operator_with_two_output(self):
        example_input = torch.rand(2, 8)
        out_1, out_2 = torch.cummax(example_input, 1)

        def run_cummax(example_input, out_1, out_2):
            output_1, output_2 = torch.cummax(example_input, 1, out=(out_1, out_2))
            return output_1, output_2

        trace_model = torch.jit.trace(run_cummax, (example_input, out_1, out_2))

    def test_trace_namedtuple(self):
        Point = namedtuple("point", ["x", "y"])

        def f(p):
            if type(p) is tuple:
                p = Point(*p)
            return p.x + p.y

        p = Point(torch.randn(1), torch.randn(1))
        traced = torch.jit.trace(f, (p,))
        self.assertEqual(f(p), traced(p))

    def test_trace_topk(self):
        class M(torch.nn.Module):
            def forward(self, x, y):
                return x.topk(y, dim=1)[1]

        mod = M()
        inputs = (torch.randint(0, 10, (20, 20)), torch.tensor(17))
        traced_func = torch.jit.trace(mod, inputs)

        test_inputs = (torch.randint(0, 9, (9, 9)), torch.tensor(8))
        eager_out = mod(*test_inputs)
        traced_out = traced_func(*test_inputs)
        self.assertNotWarn(
            lambda: traced_func(*test_inputs),
            "Shouldn't throw a slicing-related warning here",
        )
        self.assertEqual(eager_out, traced_out)

        test_inputs = (torch.randint(0, 50, (50, 50)), torch.tensor(12))
        eager_out = mod(*test_inputs)
        traced_out = traced_func(*test_inputs)
        self.assertNotWarn(
            lambda: traced_func(*test_inputs),
            "Shouldn't throw a slicing-related warning here",
        )
        self.assertEqual(eager_out, traced_out)

    def test_typeas_trace_check(self):
        a = torch.tensor([0.4], requires_grad=True)
        b = torch.tensor([0.7], requires_grad=True)

        def f(x, y):
            return x.type_as(y)

        trace = torch.jit.trace(f, (a, b))

    def test_trace_index(self):
        x = torch.tensor([0.4], requires_grad=True)
        y = torch.tensor([0], dtype=torch.int64)

        def fn(x, y):
            return x[y]

        fn_traced = torch.jit.trace(
            fn,
            (
                x,
                y,
            ),
        )

        self.assertEqual(fn(x, y), fn_traced(x, y))

    # Backwards tracing was broken for indexing by a constant,
    # because it's internally implemented using as_strided,
    # and we attempted to trace its derivative (which is not
    # currently supported.)  It currently works because
    # slice() is now not marked as traceable.
    def test_trace_index_constant(self):
        x = torch.tensor([0.4], requires_grad=True)

        def fn(x):
            return x[0]

        def run(f):
            y = f(x)
            grad = torch.autograd.grad(y, x)[0].clone()
            return y, grad

        traced_fn = torch.jit.trace(fn, torch.ones(1))
        self.assertEqual(run(fn), run(traced_fn))

    def test_index_put(self):
        ten = torch.zeros(3, 3)
        mask = torch.tensor(
            [[True, True, True], [True, False, False], [True, True, False]]
        )

        def test_fn(ten, mask):
            ten[mask] = torch.ones(6)
            return ten

        traced_test_fn = torch.jit.trace(test_fn, (ten, mask))

        ten = torch.rand(3, 3)
        self.assertEqual(test_fn(ten, mask), traced_test_fn(ten, mask))

    def test_canonicalize_tensor_iterator(self):
        x = torch.randn(4, 4)

        def f(x):
            x = x + 2
            x = x - 4
            x = x * 6
            x = x / 8
            return x

        traced = torch.jit.trace(f, (x,))
        f(x)
        graph = traced.graph_for(x)
        # There should be 4 int constants for the right sides of operators, plus one
        # for the alpha argument for add and sub
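        # (presumably the scalars 2, 4, 6, and 8 above, plus a shared alpha=1)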
        self.assertTrue(str(traced.graph_for(x)).count(": int = prim::Constant") == 5)

    @suppress_warnings
    def test_constant(self):
        x = torch.randn(2, 2, requires_grad=True)

        def f(x):
            return x.matmul(torch.diag(torch.tensor([2.0, 2.0])))

        self.checkTrace(f, (x,), (torch.ones(2, 2, requires_grad=True),))

    def test_wrapped_number(self):
        # Scalars get converted to 'wrapped' tensors of the default tensor type.
        # Wrapped tensors behave differently in certain promotion operations:
        # float_tensor * double -> float but wrapped_float * double -> double.
        # This can cause issues in check-trace if not handled correctly in
        # `aten::isclose()`.

        def foobar():
            x = -10000.0
            result = x * torch.ones(1, dtype=torch.float)
            return result

        scripted = torch.jit.trace(foobar, (), check_trace=True)

    def test_inplace_transplant(self):
        x = torch.tensor([0.0], requires_grad=True)

        def fn(x):
            y = x.clone()
            y.add_(2)
            y.add_(3)
            return y

        g, _ = torch.jit._get_trace_graph(fn, (x,))
        self.run_pass("dce", g)
        FileCheck().check_count("aten::clone", 1, exactly=True).check_count(
            "aten::add_", 2, exactly=True
        ).check_next("return").run(str(g))
        self.assertExportImport(g, (x,))

    def test_inplace_flags(self):
        class InplaceFn(Function):
            @staticmethod
            def forward(ctx, x):
                ctx.mark_dirty(x)
                return x.add_(1)

            @staticmethod
            def backward(ctx, go):
                return go

        class RegularFn(Function):
            @staticmethod
            def forward(ctx, x):
                return x.add(1)

            @staticmethod
            def backward(ctx, go):
                return go

        x = torch.tensor([0.0], requires_grad=True)

        def fn(x):
            y = RegularFn.apply(x)
            y = InplaceFn.apply(y)
            y = InplaceFn.apply(y)
            y = RegularFn.apply(y)
            return y

        trace_graph, _ = torch.jit._get_trace_graph(fn, (x,), _force_outplace=True)
        self.run_pass("dce", trace_graph)
        ops = list(trace_graph.nodes())
        for op in ops:
            self.assertTrue(op.hasAttribute("inplace"))
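        # Each PythonOp node records, in its "inplace" attribute, whether the
        # corresponding autograd.Function marked its input dirty via ctx.mark_dirty,
        # even though the trace was taken with _force_outplace=True.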
        inplace_flags = [False, True, True, False]
        for op, is_inplace in zip(ops, inplace_flags):
            self.assertEqual(op.i("inplace"), is_inplace)

    def test_inplace_check(self):
        class MyInplaceFn(Function):
            @staticmethod
            def forward(self, x):
                x.add_(1)
                self.mark_dirty(x)
                return x

            @staticmethod
            def backward(self, grad):
                return grad

        def fn(x):
            return MyInplaceFn.apply(x)

        x = torch.randn(5, 5)
        ge = torch.jit.trace(fn, (x,), _force_outplace=True, check_trace=False)
        with self.assertRaisesRegex(RuntimeError, "inplace MyInplaceFn"):
            ge(x)

    def test_force_outplace_check_fill(self):
        def f(x):
            return torch.empty(x.shape).fill_(7)

        x = torch.randn(10, 15)
        ft = torch.jit.trace(f, x, _force_outplace=True)
        self.assertEqual(f(x), ft(x))

    def test_force_outplace_check_zero(self):
        def f(x):
            return torch.empty(x.shape).zero_()

        x = torch.randn(10, 15)
        ft = torch.jit.trace(f, x, _force_outplace=True)
        self.assertEqual(f(x), ft(x))

    def do_trace_size(self, requires_grad):
        def fn(x):
            return x.view(x.shape[1] * 2, x.size(0), 2)

        x = torch.randn(5, 2, 4, requires_grad=requires_grad)
        y = torch.randn(4, 8, 4, requires_grad=requires_grad)

        # Check that it behaves as expected
        traced_fn = torch.jit.trace(fn, x)
        self.assertEqual(traced_fn(y), fn(y))
        self.assertEqual(traced_fn(x), fn(x))

    def test_trace_size(self):
        self.do_trace_size(False)

    # test the different graph_executor path that happens when
    # gradients are required and sizes are involved
    def test_trace_size_with_grad(self):
        self.do_trace_size(True)

    def test_trace_numel(self):
        def fn(x):
            return x.numel()

        x = torch.randn(2, 3, 4)
        y = torch.randn(4, 5, 6)

        traced_fn = torch.jit.trace(fn, x)
        self.assertEqual(traced_fn(y), fn(y))
        self.assertEqual(traced_fn(x), fn(x))

    def do_trace_arange(self, requires_grad):
        def arange(x):
            return torch.arange(x.shape[0])

        def arange_scalar(x):
            return torch.arange(12)

        def arange_start_end(x):
            return torch.arange(start=x.shape[0], end=x.shape[0] + 5)

        x = torch.randn(5, 3, 2, requires_grad=requires_grad)
        y = torch.randn(8, 2, 4, requires_grad=requires_grad)

        # Check that it behaves as expected
        traced_arange = torch.jit.trace(arange, x)
        self.assertEqual(traced_arange(y), arange(y))
        self.assertEqual(traced_arange(x), arange(x))

        traced_arange_scalar = torch.jit.trace(arange_scalar, x)
        self.assertEqual(traced_arange_scalar(y), arange_scalar(y))
        self.assertEqual(traced_arange_scalar(x), arange_scalar(x))

        traced_arange_start_end = torch.jit.trace(arange_start_end, x)
        self.assertEqual(traced_arange_start_end(y), arange_start_end(y))
        self.assertEqual(traced_arange_start_end(x), arange_start_end(x))

    def test_trace_arange(self):
        self.do_trace_arange(False)

    # test the different graph_executor path that happens when
    # gradients are required and sizes are involved
    def test_trace_arange_with_grad(self):
        self.do_trace_arange(True)

    # Test that a trace of torch.full(x.shape) doesn't store the shape as a constant
    def test_trace_full_dynamic_shape(self):
        def full_with_shape_like(x):
            return torch.full(x.shape, 2.0)

        x = torch.randn(3, 4)
        ge = torch.jit.trace(full_with_shape_like, example_inputs=x)
        y = torch.randn(2, 7)
        self.assertEqual(ge(y).shape, y.shape)
        self.assertEqual(ge(x).shape, x.shape)

    # Test that the trace of setitem doesn't store shapes as constants
    # Fix https://github.com/pytorch/pytorch/issues/43548
    def test_trace_slice_setitem_dynamic_shape(self):
        def slice_setitem(x, y):
            x[:, 2] = y + 1
            return x

        x = torch.randn(3, 4)
        traced = torch.jit.trace(slice_setitem, (x, x[:, 0]))
        x = torch.randn(10, 5)
        self.assertEqual(traced(x.clone(), x[:, 0]), slice_setitem(x.clone(), x[:, 0]))

    # Suppression: we are intentionally slicing a tensor; we don't care that it
    # will be constantified
    @suppress_warnings
    def do_trace_slice(self, requires_grad):
        def slice(x):
            results = []
            for i in range(4):
                results.append(x[: x.size(0) - i, i : x.size(2), i:3])
            return tuple(results)

        def slice_select(x):
            results = []
            for i in range(4):
                results.append(x[:, i:, x.size(2) - 5])
            return tuple(results)

        x = torch.randn(5, 6, 7, requires_grad=requires_grad)
        y = torch.randn(7, 8, 9, requires_grad=requires_grad)

        # Check that it behaves as expected
        traced_slice = torch.jit.trace(slice, x)
        self.assertEqual(traced_slice(y), slice(y))
        self.assertEqual(traced_slice(x), slice(x))

        traced_slice_select = torch.jit.trace(slice_select, x)
        self.assertEqual(traced_slice_select(y), slice_select(y))
        self.assertEqual(traced_slice_select(x), slice_select(x))

    def test_trace_slice(self):
        self.do_trace_slice(False)

    # test the different graph_executor path that happens when
    # gradients are required and sizes are involved
    def test_trace_slice_with_grad(self):
        self.do_trace_slice(True)

    def test_trace_casts(self):
        casts = [
            lambda x: x.byte(),
            lambda x: x.float(),
            lambda x: x.cpu(),
            lambda x: x.to(device="cpu"),
            lambda x: x.to(dtype=torch.int64),
            lambda x: x.to(device="cpu", dtype=torch.float),
            lambda x: x.to(x),
        ]

        def assertContainsCast(trace):
            self.assertEqual(
                sum(n.kind() == "aten::to" for n in trace.graph.nodes()), 1
            )

        for cast in casts:
            trace = torch.jit.trace(cast, torch.randn(2, 2))
            assertContainsCast(trace)
            x = torch.randn(2, 2)
            self.assertEqual(trace(x), cast(x))

        def to_tensor(x, y):
            return x.to(y)

        to_tensor_trace = torch.jit.trace(
            to_tensor, (torch.randn(2, 2), torch.randn(1, 8))
        )
        assertContainsCast(to_tensor_trace)
        x, y = torch.randn(2, 2), torch.randn(1, 10)
        self.assertEqual(to_tensor_trace(x, y), to_tensor(x, y))

    @skipIfCompiledWithoutNumpy
    @skipIfCrossRef
    def test_trace_warn(self):
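        # Each of the operations flagged below converts a traced Tensor into a
        # concrete Python value (int/bool/float/list/ndarray) or iterates over it,
        # so the result gets baked into the trace as a constant and the tracer
        # emits a TracerWarning.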
        def fn(x):
            int(x)  # Warning 1.
            y = x * 1
            if y:  # Warning 2.
                pass
            q = [x, x * 4]
            z = q[y]
            float(z)  # Warning 3.
            z.tolist()  # Warning 4.
            z.numpy()  # Warning 5.
            for _ in torch.ones(4, 4):  # Warning 6.
                pass
            return z + 4

        with warnings.catch_warnings(record=True) as warns:
            traced_fn = torch.jit.trace(fn, torch.tensor([1]))
        for warn in warns:
            self.assertIs(warn.category, torch.jit.TracerWarning)
        warns = [str(w.message) for w in warns]
        self.assertIn("a Python integer", warns[0])
        self.assertIn("a Python boolean", warns[1])
        self.assertIn("a Python float", warns[2])
        self.assertIn("a Python list", warns[3])
        self.assertIn("a NumPy array", warns[4])
        self.assertIn("Iterating over", warns[5])

    def test_trace_tuple(self):
        def fn(x, y):
            return x, (x * y[1], x * y[0])

        x, y = torch.randn(2, 2), (torch.ones(2, 2), torch.randn(2, 2))
        traced_fn = torch.jit.trace(fn, (x, y))
        self.assertEqual(traced_fn(x, y), fn(x, y))
        # should be a tuple nested within another tuple
        FileCheck().check_count("prim::TupleConstruct", 2, exactly=True).check_next(
            "return"
        ).run(str(traced_fn.graph))
        self.assertExportImport(traced_fn.graph, (x, y))

    def test_trace_random(self):
        def f(mean, std):
            return torch.normal(mean, std)

        traced = torch.jit.trace(
            f, (torch.zeros(2, 3), torch.ones(2, 3)), check_trace=False
        )
        mean, std = torch.zeros(5, 5), torch.ones(5, 5)
        with torch.random.fork_rng(devices=[]):
            output = f(mean, std)
        traced_output = traced(mean, std)
        self.assertEqual(output, traced_output)

    def test_trace_tensor_factory(self):
        def run(**kwargs):
            inputs_require_grads = kwargs.pop("inputs_require_grads", True)

            def fn(x):
                return x + torch.ones(2, 3, **kwargs)

            input_kwargs = kwargs.copy()
            if "out" in input_kwargs:
                del input_kwargs["out"]
            input = torch.ones(2, 3, **input_kwargs)
            self.checkTrace(fn, (input,), inputs_require_grads=inputs_require_grads)
            # check we recorded 'ones' and did not just record a constant
            tfn = torch.jit.trace(fn, input)
            self.assertTrue("ones" in str(tfn.graph))

        run()
        run(dtype=torch.int, inputs_require_grads=False)
        run(out=torch.tensor([]))
        if RUN_CUDA:
            run(device="cuda:0")
        if RUN_CUDA_MULTI_GPU:
            run(device="cuda:1")

    def test_trace_indexed_assignment(self):
        def stuff(x, y):
            x = x.clone()
            x[0] = y
            return x

        example = torch.rand(3, 4)
        self.checkTrace(stuff, (example, example[0] + 1))

    # TODO: implement
    @unittest.expectedFailure
    def test_output_unflatten(self):
        """Check that outputs of traced functions retain the original structure and nesting"""

        def fn(x):
            return (
                x * 2,
                (
                    x**2,
                    x + 4,
                    (x + 2,),
                ),
                x * 4,
            )

        self.checkTrace(fn, (torch.randn(2, 2),))

    def test_input_flatten(self):
        """Check that inputs to traced functions are flattened"""

        def fn(x, t):
            y, z = t
            return x * y * z

        inputs = (torch.randn(1), (torch.randn(1), torch.randn(1)))
        self.checkTrace(fn, inputs)

    def test_input_dict_empty(self):
        def test(d):
            pass

        with self.assertRaises(RuntimeError):
            self.checkTrace(test, {})

    def test_input_dict_remembers_keys(self):
        """Check that the trace remembers which keys were in a dict input"""

        class TestModule(torch.nn.Module):
            def forward(self, dict_input):
                return dict_input["x"]

        input_1 = {"x": torch.tensor(1)}
        m = TestModule()
        m_traced = torch.jit.trace(m, (input_1,))
        self.assertEqual(m_traced(input_1), torch.tensor(1))

        # should work to change the values and not the keys
        input_same_key_different_value = {"x": torch.tensor(2)}
        self.assertEqual(m_traced(input_same_key_different_value), torch.tensor(2))

        # error to use something that doesn't have `x`
        input_different_key = {"y": torch.tensor(3)}
        with self.assertRaises(RuntimeError):
            m_traced(input_different_key)

        # it's okay to have additional elements in the dictionary, so long as 'x' is there
        input_additional_key = {"x": torch.tensor(4), "y": torch.tensor(3)}
        self.assertEqual(m_traced(input_additional_key), torch.tensor(4))

    def test_input_dict_insertion_order(self):
        """Check that dictionary access doesn't care about insertion order"""

        class TestModule(torch.nn.Module):
            def forward(self, dict_input):
                return dict_input["x"], dict_input["y"]

        input_x_then_y = {}
        input_x_then_y["x"] = torch.tensor(1)
        input_x_then_y["y"] = torch.tensor(2)

        m = TestModule()
        m_traced = torch.jit.trace(m, (input_x_then_y,))

        self.assertEqual(m_traced(input_x_then_y), (torch.tensor(1), torch.tensor(2)))

        input_y_then_x = {}
        input_y_then_x["y"] = torch.tensor(4)
        input_y_then_x["x"] = torch.tensor(3)

        self.assertEqual(m_traced(input_y_then_x), (torch.tensor(3), torch.tensor(4)))

    def test_input_dict_recursive(self):
        class TestModule(torch.nn.Module):
            def forward(self, dict_input):
                return dict_input["x"][1]

        input_1 = {"x": {1: torch.tensor(1)}}
        m = TestModule()
        m_traced = torch.jit.trace(m, (input_1,))

        input_2 = {"x": {1: torch.tensor(2)}}
        self.assertEqual(m_traced(input_2), torch.tensor(2))

    def test_input_dict_checkTrace_mut(self):
        def test(d):
            d["x"].tanh_()
            return d["x"]

        inputs = {"x": torch.rand(3, 4), "y": torch.rand(3, 4)}
        self.checkTrace(test, (inputs,), inputs_require_grads=False)

    def test_input_dict_unify(self):
        def test(d):
            return d["int"], d["float"]

        inputs = {
            "int": torch.ones((2, 2), dtype=torch.int32),
            "float": torch.ones((2, 2), dtype=torch.float32),
        }
        self.checkTrace(test, (inputs,), inputs_require_grads=False)

    def test_input_tuple_of_dicts(self):
        def test(t):
            d = t[0]
            return d["x"]["y"]

        inputs = {"x": {"y": torch.rand(2, 3)}}
        self.checkTrace(test, ((inputs, inputs),), allow_unused=True)

    def test_input_dict_of_dicts(self):
        def test(d):
            return d["x"]["y"]

        nested_input = {"y": torch.rand(2, 3)}
        unified_nested = {"y": torch.rand(3, 2)}
        inputs = {"x": nested_input, "force_unify": unified_nested}
        self.checkTrace(test, (inputs,), allow_unused=True)

    def test_input_dict_of_lists(self):
        def test(d):
            return d["x"][0]

        inputs = {"x": [torch.rand(3, 2)]}
        self.checkTrace(test, (inputs,))

    def test_input_list_toplevel_flatten(self):
        def test(t1, t2):
            return torch.add(t1, t2)

        inputs = [torch.ones(2, 2), torch.rand(2, 2)]
        self.checkTrace(test, inputs)

    def test_input_list_toplevel_flatten_direct(self):
        class Test(torch.nn.Module):
            def forward(self, t1, t2):
                return torch.add(t1, t2)

        inputs = [torch.ones(2, 2), torch.rand(2, 2)]
        torch.jit.trace(Test(), inputs)

    def test_input_list_of_tuples(self):
        def test(l):
            return l[0][0]

        inputs = [(torch.ones(2, 2),)]
        self.checkTrace(test, (inputs,))

    def test_input_dict_empty_list(self):
        def test(d):
            pass

        inputs = {1: []}
        with self.assertRaisesRegex(RuntimeError, "List trace"):
            self.checkTrace(test, (inputs,))

    def test_input_list_mixed_type(self):
        def test(d):
            pass

        inputs = [torch.rand(2, 3), (torch.ones(2), torch.ones(2))]
        with self.assertRaisesRegex(RuntimeError, "consistent"):
            self.checkTrace(test, (inputs,))

    def test_conv(self):
        x = torch.ones(20, 16, 50, 40)
        g, outputs, inputs = torch.jit._get_trace_graph(
            nn.Conv2d(16, 13, 3, bias=False), x, return_inputs=True
        )
        m = self.createFunctionFromGraph(g)
        self.assertEqual(outputs, m(*inputs))

    def test_max_pool(self):
        x = torch.rand(20, 16, 10, 10)

        def max_pool2d(x):
            return F.max_pool2d(x, 2) + 2

        trace = torch.jit.trace(max_pool2d, (x,))
        graph = trace.graph_for(x)
        FileCheck().check("aten::max_pool2d(").run(graph)
        self.assertEqual(max_pool2d(x), trace(x))

    def test_nested_inplace(self):
        x = torch.randn(2, 2)
        g, outputs, inputs = torch.jit._get_trace_graph(
            lambda x: F.threshold(x, 0, 0, inplace=True), (x,), return_inputs=True
        )
        m = self.createFunctionFromGraph(g)
        self.assertEqual(outputs, m(*inputs))
        FileCheck().check("threshold_").run(str(g))
        self.assertExportImport(g, (x,))

    def test_repeated_input(self):
        def fn(a, b):
            return a + b

        ge = self.checkTrace(fn, [torch.randn(2, 2)] * 2)
        inputs = set(ge.graph.inputs())
        # 3 instead of 2 because the export/import in checkTrace adds a
        # `self` module argument
        self.assertTrue(len(inputs) == 3)

    def test_repeated_output(self):
        def fn(a, b):
            z = a + b
            return z, z

        ge = self.checkTrace(fn, [torch.randn(2, 2) for _ in range(2)])
        tuple_output = list(ge.graph.outputs())[0]
        tuple_inputs = list(tuple_output.node().inputs())
        self.assertTrue(tuple_inputs[0] == tuple_inputs[1])

    def test_inplace_copy(self):
        x = torch.randn(4, 4, requires_grad=True)

        def f(x):
            out = torch.zeros(x.size())
            out.copy_(x)
            return out

        g, outputs, inputs = torch.jit._get_trace_graph(f, (x,), return_inputs=True)
        self.run_pass("dce", g)
        m = self.createFunctionFromGraph(g)
        self.assertEqual(outputs, m(*inputs))
        self.assertExportImport(g, (x,))

    def test_inplace_copy_force_outplace(self):
        x = torch.randn(4, 4, requires_grad=True)

        def f(x):
            out = torch.zeros(x.size())
            out.copy_(x)
            return out

        g, outputs, inputs = torch.jit._get_trace_graph(
            f, (x,), return_inputs=True, _force_outplace=True
        )
        self.run_pass("dce", g)
        m = self.createFunctionFromGraph(g)
        self.assertEqual(outputs, m(*inputs))
        self.assertExportImport(g, (x,))
        FileCheck().check("expand_as").run(str(g))

    def test_shared_param(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = self.a = nn.Parameter(torch.randn(2, 2))

            def forward(self, x):
                return x * self.a + self.b

        m = MyModule()
        g, _ = torch.jit._get_trace_graph(m, (torch.randn(2, 2),))
        self.run_pass("dce", g)
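        # self.a and self.b alias the same Parameter, so the trace should expose a
        # single parameter input in addition to x: 2 graph inputs in total.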
        self.assertEqual(len(list(g.inputs())), 2)
        FileCheck().check("mul").check("add").run(str(g))

    def run_ge_tests(self, optimize, use_cuda):
        with enable_profiling_mode_for_profiling_tests():
            with torch.jit.optimized_execution(optimize):

                def rand(*args):
                    t = torch.rand(*args).float()
                    if use_cuda:
                        t = t.cuda()
                    return t

                self.checkTrace(
                    lambda a, b: a * b + b, [rand(1), rand(1)], [rand(2, 3), rand(2, 3)]
                )
                # trivial identity
                self.checkTrace(lambda a, b: (b, a), [rand(1), rand(1)])

                def foo(a):
                    t = a * a
                    return t * t, 4 * t

                self.checkTrace(foo, [rand(1)])
                # unused input
                self.checkTrace(
                    lambda a, b: a * a, [rand(1), rand(1)], allow_unused=True
                )
                # test outputs that do not get used in grad
                self.checkTrace(foo, [rand(1)], drop=1)
                # test autograd fallback
                self.checkTrace(
                    lambda a, b: a * b / (a - 2 * b) + b, [rand(1), rand(1)]
                )

    def test_ge_unoptimized(self):
        self.run_ge_tests(False, False)

    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
    @enable_cpu_fuser
    def test_ge_optimized(self):
        with enable_profiling_mode_for_profiling_tests():
            self.run_ge_tests(True, False)

    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    def test_ge_cuda(self):
        self.run_ge_tests(True, True)

    # more manual test of graph executor that can be used as a scratchpad
    def test_ge(self):
        def foo(a, b):
            return a * b / (a - b) + b

        V = Variable
        a, b = V(torch.rand(1)), V(torch.rand(1))
        ge = torch.jit.trace(foo, (a, b))
        a, b = V(torch.rand(1), requires_grad=True), V(
            torch.rand(1), requires_grad=True
        )
        (r,) = ge(a, b)
        da, db = torch.autograd.grad(r + 3, [a, b], create_graph=True)

        l2 = da * db + db * db
        g2result = torch.autograd.grad(l2, [da, db])

        r = foo(a, b)
        da2, db2 = torch.autograd.grad(r + 3, [a, b], create_graph=True)
        self.assertEqual(da, da2)
        self.assertEqual(db, db2)
        l3 = da2 * db2 + db2 * db2
        g2result2 = torch.autograd.grad(l3, [da2, db2])
        self.assertEqual(g2result, g2result2)

    def test_trace_annotation(self):
        @_trace(torch.rand(1))
        def foo(a):
            return a + a + a

        x = torch.randn(5, 5)
        self.assertEqual(foo(x), x + x + x)

    @unittest.skipIf(not RUN_CUDA, "calls .cuda()")
    # By default, on Ampere or later GPUs, nn.Linear computes float tensors at TF32 precision.
    # We want float tensors to be computed at full precision so the comparison uses the
    # default precision.
    @with_tf32_off
    def test_traced_module_cuda(self):
        class Model(nn.Module):
            def __init__(self, num_features, num_layers):
                super().__init__()
                self.num_layers = num_layers
                layers = [
                    [nn.Linear(num_features, num_features), nn.Sigmoid()]
                    for _ in range(num_layers)
                ]
                self.submodule = nn.Sequential(*chain(*layers))

            def forward(self, x):
                for i in range(self.num_layers):
                    x = self.submodule[i](x) + x
                return x

        model = Model(5, 3)
        x = torch.randn(2, 5)
        traced_model = torch.jit.trace(model, x)

        # We're missing some attributes these modules had initially. Make sure we can
        # still get the __repr__()
        model.__repr__()

        # XXX: indexing sequentials is broken
        linear_submodule = next(iter(traced_model.submodule._modules.values()))

        # All attributes that aren't parameters should raise
        with self.assertRaises(AttributeError):
            linear_submodule.in_features
        linear_submodule.weight
        linear_submodule.weight = nn.Parameter(
            torch.randn(linear_submodule.weight.shape)
        )
        with self.assertRaises(RuntimeError):
            del linear_submodule.weight

        # Submodules can't be called
        with self.assertRaises(RuntimeError):
            linear_submodule(x)

        # Type casts
        linear_submodule.cuda()
        traced_model.float().cuda()
        cuda_out = traced_model(x.float().cuda())
        traced_model.cpu()
        cpu_out = traced_model(x.float())
        self.assertEqual(cpu_out, cuda_out)
        traced_model.to("cuda")
        cuda_out = traced_model(x.float().cuda())
        traced_model.to("cpu")
        cpu_out = traced_model(x.float())
        self.assertEqual(cpu_out, cuda_out)
        traced_model.to(torch.get_default_dtype())

        # state_dict + load_state_dict
        state = {k: v.clone() for k, v in traced_model.state_dict().items()}
        new_state = {k: v.clone().fill_(1) for k, v in state.items()}
        out = traced_model(x)
        traced_model.load_state_dict(new_state)
        out_ones = traced_model(x)
        traced_model.load_state_dict(state)
        out_state = traced_model(x)
        self.assertEqual(out, out_state)
        self.assertNotEqual(out, out_ones)

    @unittest.skipIf(not RUN_CUDA, "uses cuda")
    def test_type_same_device(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.dtype = torch.float16

            def forward(self, x=None):
                h = x.type(self.dtype)
                return h

        a = Model()
        b = torch.jit.trace(
            a, example_inputs=(torch.ones([1], device=torch.device("cuda")),)
        )
        FileCheck().check_not("device").run(b.code)

    def test_export_no_reorder(self):
        def func(a, b):
            return a * b / (a - 2 * b) + b

        recording_inputs = [
            torch.tensor(
                [0.55619788169860839844], dtype=torch.float32, requires_grad=True
            ),
            torch.tensor(
                [0.25947844982147216797], dtype=torch.float32, requires_grad=True
            ),
        ]

        ge1 = torch.jit.trace(func, recording_inputs)
        ge2 = self.getExportImportCopy(ge1)

        outputs_ge1 = ge1(*recording_inputs)
        outputs_ge2 = ge2(*recording_inputs)

        grad_ge1 = torch.autograd.grad(outputs_ge1, recording_inputs)
        grad_ge2 = torch.autograd.grad(outputs_ge2, recording_inputs)
        self.assertTrue(outputs_ge1 == outputs_ge2)
        self.assertTrue(grad_ge1 == grad_ge2)

    def test_python_function(self):
        class MyFn(Function):
            @staticmethod
            def forward(ctx, x):
                return x + 1

            @staticmethod
            def backward(ctx, grad_output):
                return grad_output

        @_trace(torch.zeros(2))
        def fn(x):
            return MyFn.apply(x + 2) + 3

        x = torch.tensor([1.0, 2.0, 3.0])
        y = torch.randn(2, 2, requires_grad=True)
        fn(x)
        fn(y)

    def test_python_function_tup(self):
        class MyFn(Function):
            @staticmethod
            def forward(ctx, x):
                return x + 1, x - 1

            @staticmethod
            def backward(ctx, grad_output):
                return grad_output, grad_output

        @_trace(torch.zeros(2))
        def fn(x):
            a, b = MyFn.apply(x + 2)
            return a + b + 3

        x = torch.tensor([1.0, 2.0, 3.0])
        y = torch.randn(2, 2, requires_grad=True)
        fn(x)
        fn(y)

    def test_trace_detach(self):
        def foo(x, w):
            return torch.matmul(x, w).detach()

        traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))

        FileCheck().check("matmul").check("detach").run(str(traced.graph))
        x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
        traced_result = traced(x, w)
        self.assertEqual(foo(x, w), traced_result)
        self.assertFalse(traced_result.requires_grad)
        self.assertIsNone(traced_result.grad_fn)

    def test_trace_detach_redispatch(self):
        def foo(x, w):
            y = torch.matmul(x, w)
            assert y.requires_grad
            y = y.detach()
            # Make sure trace kernel redispatches to the right lower kernel.
            assert not y.requires_grad
            return y

        x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
        # With `check_trace=True` the check runs under `torch.no_grad()` and would break the asserts.
        torch.jit.trace(foo, (x, w), check_trace=False)

    def test_trace_detach_inplace(self):
        def foo(x, w):
            y = torch.matmul(x, w)
            y.detach_()
            return y

        traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))

        FileCheck().check("matmul").check("detach(").run(str(traced.graph))
        x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
        traced_result = traced(x, w)
        self.assertEqual(foo(x, w), traced_result)
        self.assertFalse(traced_result.requires_grad)
        self.assertIsNone(traced_result.grad_fn)

    def test_trace_detach_inplace_redispatch(self):
        def foo(x, w):
            y = torch.matmul(x, w)
            assert y.requires_grad
            y.detach_()
            # Make sure trace kernel redispatches to the right lower kernel.
            assert not y.requires_grad
            return y

        x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
        # With `check_trace=True` the check runs under `torch.no_grad()` and would break the asserts.
        torch.jit.trace(foo, (x, w), check_trace=False)

    def test_trace_slice_full_dim(self):
        def foo(x):
            return x[0:5, 0] + 1.0

        traced = torch.jit.trace(foo, (torch.rand(5, 4),))
        test_x = torch.rand(6, 3)
        self.assertEqual(foo(test_x), traced(test_x))

    def test_trace_dict_input(self):
        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Foo()

            def forward(self, a, b):
                return self.foo({"a": a, "b": b})["a"]

        class Foo(torch.nn.Module):
            def forward(self, x):
                return {"a": x["a"] * x["b"]}

        x = (torch.rand(3), torch.rand(3))
        model = Bar()
        self.checkTrace(model, x)

    def test_trace_dict_output(self):
        class TraceDictStrTensor(torch.nn.Module):
            def forward(self, a, b):
                return {"a": a, "b": b}

        class TraceDictTensorTensor(torch.nn.Module):
            def forward(self, a, b):
                return {a: b, b: a}

        x = (torch.rand(3), torch.rand(3))
        with self.assertRaisesRegex(RuntimeError, r"Encountering a dict at the output"):
            torch.jit.trace(TraceDictStrTensor(), x)

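        # strict=False lets the tracer accept container (dict/list) outputs instead
        # of raising; the container structure is treated as fixed for the trace.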
        traced_dict_str_mod = torch.jit.trace(TraceDictStrTensor(), x, strict=False)
        self.assertEqual(traced_dict_str_mod(*x), {"a": x[0], "b": x[1]})

        traced_dict_tensor_mod = torch.jit.trace(
            TraceDictTensorTensor(), x, strict=False
        )
        self.assertEqual(traced_dict_tensor_mod(*x), {x[0]: x[1], x[1]: x[0]})

    def test_trace_with_tensor_list_output(self):
        def f():
            return [torch.zeros(1), torch.zeros(5)]

        with self.assertWarnsRegex(
            torch.jit.TracerWarning, "cause the trace to be incorrect"
        ):
            torch.jit.trace(f, [])
        traced_non_strict_f = torch.jit.trace(f, [], strict=False)
        self.assertEqual(traced_non_strict_f(), f())

    def test_trace_with_number_list_output(self):
        def f():
            return [1, 5]

        with self.assertRaisesRegex(
            RuntimeError, r"Only tensors.+can be output from traced functions"
        ):
            traced_f = torch.jit.trace(f, [])

    def test_trace_with_nested_tensor_list_output(self):
        def f():
            return [[torch.zeros(1)], [torch.zeros(5)]]

        with self.assertRaisesRegex(
            RuntimeError, r"Only tensors.+can be output from traced functions"
        ):
            traced_f = torch.jit.trace(f, [])

    def test_trace_with_nested_strided_tensor_output(self):
        @torch.jit.script
        def nt_construct(values, kv_lengths):
            kv_lengths_list: List[int] = kv_lengths.tolist()
            return torch._nested_tensor_from_tensor_list(
                list(values.split(kv_lengths_list, dim=0)), None, None, None, None
            )

        def f(x, offsets):
            kv_lengths = offsets[1:] - offsets[:-1]
            return nt_construct(x, kv_lengths).cos()

        x = torch.rand(5, 4)
        offsets = torch.tensor([0, 2, 5])
        ref = f(x, offsets)
        f_t = torch.jit.trace(f, (x, offsets))
        res = f_t(x, offsets)
        self.assertEqual(ref, res)
        x2 = torch.rand((8, 4))
        offsets2 = torch.tensor([0, 2, 4, 8])
        self.assertEqual(f(x2, offsets2), f_t(x2, offsets2))

    def test_trace_variable_instantiation(self):
        def random_foo(x):
            return Variable(Variable(x) + 1.0)

        random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))

        x = torch.rand(5, 6)
        self.assertEqual(random_foo(x), random_foo_traced(x))

    def test_trace_slice_expr_complete_type(self):
        def random_foo(x):
            return x + 1.0

        random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))

        @torch.jit.script
        def random_bar(x):
            return random_foo_traced(x)[0:1]

        x = torch.rand(3, 4)
        self.assertEqual(random_bar(x), (x + 1)[0:1])

    def test_trace_inline_shape(self):
        # test that the peephole optimization turns size() into a constant
        # in a script fn

        @torch.jit.script
        def tensor_size(x: torch.Tensor) -> torch.Tensor:
            return torch.tensor([x.size()[0]])

        self.assertEqual(
            tensor_size(
                torch.rand(
                    15,
                )
            ),
            torch.tensor([15]),
        )

        traced_tensor_size = torch.jit.trace(
            tensor_size,
            torch.rand(
                7,
            ),
        )

        self.assertEqual(
            traced_tensor_size(
                torch.rand(
                    15,
                )
            ),
            torch.tensor([15]),
        )

        @torch.jit.script
        def use_device(x):
            return torch.zeros_like(x, device=x.device)

        def foo(x):
            return use_device(x)

        traced_tensor_size = torch.jit.trace(
            foo,
            torch.rand(
                7,
            ),
        )
        self.run_pass("inline", traced_tensor_size.graph)
        FileCheck().check("prim::device").run(traced_tensor_size.graph)

    def test_trace_save(self):
        def fn(x):
            return x + 2

        def check(func):
            with TemporaryFileName() as fname:
                func.save(fname)
                loaded = torch.jit.load(fname)
                input = torch.randn(2, 2)
                self.assertEqual(func(input), loaded(input))

        out = torch.jit.trace(fn, (torch.ones(2, 2),))
        check(out)

    def test_trace_optional_dtype(self):
        class Test(torch.nn.Module):
            def forward(self):
                return torch.arange(5)

        traced = torch.jit.trace(Test(), ())
        torch.allclose(traced(), Test()())

    def test_trace_save_load_copy(self):
        class Test(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                return self.conv(x)

        traced = torch.jit.trace(Test(), torch.rand(1, 3, 224, 224))
        buffer = io.BytesIO()
        torch.jit.save(traced, buffer)
        buffer.seek(0)
        loaded = torch.jit.load(buffer)
        # should work
        copy.copy(loaded)
        copy.deepcopy(loaded)

    def test_trace_export_fns(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 3

            @torch.jit.export
            def __getstate__(self):
                return (3, self.training)

            @torch.jit.export
            def __setstate__(self, state):
                self.a = state[0]
                self.training = state[1]

            def forward(self, x):
                return x + self.a

        f = Foo()

        traced = torch.jit.trace(f, (torch.rand(3, 4),))
        expected_names = ["__getstate__", "__setstate__"]

        def check(mod):
            self.assertTrue(
                all(name in mod._c._method_names() for name in expected_names)
            )

        check(traced)

        imported = self.getExportImportCopy(traced)
        check(imported)

    def test_trace_export_fns_recursive(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 3

            @torch.jit.export
            def __getstate__(self):
                return (3, self.training)

            @torch.jit.export
            def __setstate__(self, state):
                self.a = state[0]
                self.training = state[1]

            def forward(self, x):
                return x + self.a

        class Wrapper(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Foo()

            def forward(self, x):
                return self.foo(x)

        f = Wrapper()

        traced = torch.jit.trace(f, (torch.rand(3, 4),))
        expected_names = ["__getstate__", "__setstate__"]

        def check(mod):
            self.assertTrue(
                all(name in mod._c._method_names() for name in expected_names)
            )

        check(traced.foo)

        imported = self.getExportImportCopy(traced)
        check(imported.foo)

        # Note that Bar's forward can only be traced, but not scripted
        class Bar(nn.Module):
            @torch.jit.export
            def addTwo(self, x):
                return x + 2

            def forward(self, input):
                return (lambda a: a + 1)(input)  # noqa: PLC3002

        # When tracing Bar as a submodule, we only want to script the exported
        # methods and keep forward traced.
1492        class WrapperExports(torch.nn.Module):
1493            def __init__(self) -> None:
1494                super().__init__()
1495                self.bar = Bar()
1496
1497            @torch.jit.export
1498            def addOne(self, x):
1499                return x + 1
1500
1501            def forward(self, x):
1502                return self.bar(x)
1503
1504        f = WrapperExports()
1505
1506        traced = torch.jit.trace(f, (torch.rand(3, 4),))
1507        expected_names = ["addOne"]
1508        check(traced)
1509
1510    def test_trace_autograd_function(self):
1511        class TestFunc(torch.autograd.Function):
1512            @staticmethod
1513            def forward(ctx, input):
1514                return torch.neg(input)
1515
1516            @staticmethod
1517            def backward(ctx, grad_output):
1518                return torch.neg(grad_output)
1519
1520        class TracedModule(torch.nn.Module):
1521            def forward(self, x):
1522                return torch.relu(TestFunc.apply(x))
1523
1524        class Wrapper(torch.nn.Module):
1525            def __init__(self) -> None:
1526                super().__init__()
1527                self.tm = TracedModule()
1528
1529            def forward(self, x):
1530                return self.tm(x)
1531
1532        traced = torch.jit.trace(Wrapper(), (torch.rand(3, 4),))
1533
1534    def test_trace_multi_output_function(self):
1535        # An autograd.Function with two outputs.
1536        # It swaps inputs so we can check if shape
1537        # handling is correct in TorchScript.
1538        class Foo(torch.autograd.Function):
1539            @staticmethod
1540            def forward(ctx, x, y):
1541                return y, x
1542
1543            @staticmethod
1544            def backward(ctx, du, dv):
1545                return dv, du
1546
1547        class Bar(torch.nn.Module):
1548            def forward(self, x, y):
1549                x = x.relu()
1550                y = y.relu()
1551                z = Foo.apply(x, y)
1552                return z
1553
1554        x = torch.rand(3, 2, dtype=torch.double)
1555        y = torch.rand(1, 2, dtype=torch.double)
1556
1557        # Generate JIT IR.
1558        traced = torch.jit.trace(Bar(), (x, y))
1559        print(traced.graph)
1560
1561        # Expected output schema of the custom autograd.Function.
1562        schema = (
1563            "(Double(1, 2, strides=[2, 1], requires_grad=0, device=cpu), "
1564            "Double(3, 2, strides=[2, 1], requires_grad=0, device=cpu)) "
1565            "= ^Foo"
1566        )
1567
1568        # See if expected schema exists.
1569        FileCheck().check(schema).run(traced.graph)
1570
1571        # Also examine if the graph is runnable and produces
1572        # the right result.
1573        u, v = traced(x, y)
1574        self.assertEqual(u, y)
1575        self.assertEqual(v, x)
1576
1577    def test_interpolate_trace(self):
1578        class test(nn.Module):
1579            def __init__(self) -> None:
1580                super().__init__()
1581                self.conv = nn.Conv2d(1, 32, kernel_size=3, padding=1)
1582
1583            def forward(self, x):
1584                y = self.conv(x)
1585                w = nn.functional.interpolate(
1586                    y, mode="bilinear", align_corners=False, scale_factor=3
1587                )
1588                return w
1589
1590        f = test()
1591        # no failure
1592        g = torch.jit.trace(f, (torch.zeros(1, 1, 28, 28),))
1593        x = torch.zeros(1, 1, 14, 14)
1594        # constants not baked in
1595        self.assertEqual(g(x), f(x))
1596
1597    @_tmp_donotuse_dont_inline_everything
1598    def test_trace_optional(self):
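        # Calls a scripted function that takes an Optional[Tensor] from traced code.
        # With inlining disabled by the decorator, the scripted call should appear
        # as a prim::CallFunction in the traced graph.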
1599        @torch.jit.script
1600        def test(x: Optional[Tensor]):
1601            if x is None:
1602                return torch.zeros(1)
1603            else:
1604                return x
1605
1606        def test_none():
1607            return test(None)
1608
1609        def test_tensor():
1610            return test(torch.zeros(2))
1611
1612        f_none = torch.jit.trace(test_none, ())
1613        self.assertEqual(f_none(), torch.zeros(1))
1614
1615        f_tensor = torch.jit.trace(test_tensor, ())
1616        self.assertEqual(f_tensor(), torch.zeros(2))
1617
1618        graph = f_tensor.graph
1619        FileCheck().check('name="test"').check_next("prim::CallFunction").run(graph)
1620
1621    def test_trace_nested_datatypes(self):
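        # foo returns nested lists of tensors; tracing bar, which indexes into that
        # structure, should still match eager execution on new inputs.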
1622        @torch.jit.script
1623        def foo(x):
1624            return [[x + 1, x - 1], [x + 2, x - 2]]
1625
1626        def bar(x):
1627            list_stuff = foo(x)
1628            return list_stuff[0][0], list_stuff[1][1]
1629
1630        traced = torch.jit.trace(bar, torch.rand(3, 4))
1631        x = torch.rand(5, 6)
1632        self.assertEqual(bar(x), traced(x))
1633
1634    @_tmp_donotuse_dont_inline_everything
1635    def test_call_traced_fn_from_traced_module(self):
1636        @_trace(torch.rand(3, 4))
1637        def traced_fn(x):
1638            return torch.neg(x)
1639
1640        class TracedModule(torch.nn.Module):
1641            def __init__(self) -> None:
1642                super().__init__()
1643                self.param = torch.nn.Parameter(torch.rand(4, 5))
1644
1645            def forward(self, x):
1646                return traced_fn(torch.mm(x, self.param))
1647
1648        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
1649
1650        # Note: inlining is disabled here, so traced_fn should appear as a prim::CallFunction, not an inlined aten::neg
1651        FileCheck().check("aten::mm").check('name="traced_fn"').check_next(
1652            "prim::CallFunction"
1653        ).run(str(tm.graph))
1654
1655    @_tmp_donotuse_dont_inline_everything
1656    def test_call_traced_module_from_traced_module(self):
1657        class TracedModule1(torch.nn.Module):
1658            def __init__(self) -> None:
1659                super().__init__()
1660                self.param = torch.nn.Parameter(torch.rand(5, 7))
1661
1662            def forward(self, x):
1663                return torch.mm(x, self.param)
1664
1665        class TracedModule(torch.nn.Module):
1666            def __init__(self) -> None:
1667                super().__init__()
1668                self.param = torch.nn.Parameter(torch.rand(4, 5))
1669                self.mod = torch.jit.trace(TracedModule1(), torch.rand(3, 5))
1670
1671            def forward(self, x):
1672                return self.mod(torch.mm(x, self.param)) + 1.0
1673
1674        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
1675
1676        FileCheck().check("aten::mm").check("prim::CallMethod").check_same(
1677            "forward"
1678        ).check("aten::add").run(str(tm.graph))
1679
1680    def test_index_put_trace_with_view(self):
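        # rhs carries extra singleton dimensions, so the tracer is expected to
        # record an aten::view before the index_put_.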
1681        @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(1, 1, 1, 4))
1682        def test_index_put(target, indices, rhs):
1683            target[indices] = rhs
1684            return target
1685
1686        FileCheck().check("aten::view").check("index_put_").run(
1687            str(test_index_put.graph)
1688        )
1689
1690    def test_index_put_trace_without_view(self):
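        # rhs already matches the shape of the indexed region, so no aten::view
        # should appear before the index_put_.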
1691        @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(4))
1692        def test_index_put(target, indices, rhs):
1693            target[indices] = rhs
1694            return target
1695
1696        FileCheck().check_not("aten::view").check("index_put_").run(
1697            str(test_index_put.graph)
1698        )
1699
1700    @suppress_warnings
1701    def test_trace_checker_dot_data(self):
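        # .data detaches the value from the trace, so y is baked in as a tensor
        # constant; the trace checker should report that this constant differs
        # across the check inputs.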
1702        with self.assertRaisesRegex(
1703            torch.jit.TracingCheckError,
1704            r"Tensor-valued Constant nodes differed in value " r"across invocations",
1705        ):
1706
1707            @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
1708            def foo(x):
1709                y = x.data
1710                return x + y
1711
1712    @suppress_warnings
1713    def test_trace_checker_control_flow(self):
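        # The loop count depends on x.size(0), so tracing with a differently sized
        # check input unrolls to a different graph, which the checker should flag.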
1714        def foo(x):
1715            for _ in range(x.size(0)):
1716                x = torch.neg(x)
1717            return x
1718
1719        with self.assertRaisesRegex(
1720            torch.jit.TracingCheckError, r"Graphs differed across invocations!"
1721        ):
1722            torch.jit.trace(foo, torch.randn(3, 4), check_inputs=[torch.randn(4, 4)])
1723
1724    @suppress_warnings
1725    def test_trace_checker_memoization(self):
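        # foo caches a tensor on its first invocation, so the graph recorded while
        # checking differs from the original trace.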
1726        with self.assertRaisesRegex(
1727            torch.jit.TracingCheckError, r"Graphs differed across invocations!"
1728        ):
1729
1730            def foo(x):
1731                if not hasattr(foo, "cache"):
1732                    foo.cache = torch.neg(x)
1733                return x + foo.cache
1734
1735            traced = torch.jit.trace(
1736                foo, torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)]
1737            )
1738
1739    def test_trace_checker_slice_lhs(self):
1740        def foo(x):
1741            for i in range(3):
1742                x[i, :] = torch.zeros(4)
1743            return x
1744
1745        self.checkTrace(foo, (torch.rand(3, 4),), inputs_require_grads=False)
1746
1747    def test_trace_checker_inplace_on_view(self):
1748        def foo(x):
1749            x.view(-1).add_(-x.view(-1))
1750            return x
1751
1752        with self.assertWarnsRegex(
1753            torch.jit.TracerWarning,
1754            "Output nr 1. of the traced function does not match the "
1755            "corresponding output of the Python function",
1756        ):
1757            torch.jit.trace(
1758                foo,
1759                torch.rand(3, 4),
1760                check_inputs=[torch.rand(5, 6)],
1761                _force_outplace=True,
1762            )
1763
1764    def test_lhs_index_fails(self):
1765        def foo(x):
1766            x[0, 1] = 4
1767            return x
1768
1769        with self.assertWarnsRegex(
1770            torch.jit.TracerWarning, "cause the trace to be incorrect"
1771        ):
1772            torch.jit.trace(foo, torch.rand(3, 4), _force_outplace=True)
1773
1774    def test_lhs_index_trivial(self):
1775        def foo(y, x):
1776            y[...] = x
1777            return y
1778
1779        self.checkTrace(
1780            foo, (torch.rand(3, 4), torch.rand(4)), inputs_require_grads=False
1781        )
1782
1783    def test_inplace_warn(self):
1784        def foo(x):
1785            x.view(-1).add_(-x.view(-1))
1786            return x
1787
1788        with self.assertWarnsRegex(
1789            torch.jit.TracerWarning, "cause the trace to be incorrect"
1790        ):
1791            torch.jit.trace(foo, torch.rand(3, 4), _force_outplace=True)
1792
1793    @suppress_warnings
1794    def test_trace_checker_dropout_train(self):
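        # Dropout with train=True is nondeterministic, so the trace checker should
        # warn both about mismatched outputs and about nondeterministic nodes.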
1795        def foo(x):
1796            return torch.dropout(x, p=0.5, train=True)
1797
1798        with self.assertWarnsRegex(
1799            torch.jit.TracerWarning,
1800            "Output nr 1. of the traced function does not match the "
1801            "corresponding output of the Python function",
1802        ):
1803            torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)])
1804
1805        with self.assertWarnsRegex(
1806            torch.jit.TracerWarning, "Trace had nondeterministic nodes"
1807        ):
1808            torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)])
1809
1810    def test_trace_checker_dropout_notrain(self):
1811        input = torch.rand(3, 4)
1812
1813        @_trace(input)
1814        def foo(x):
1815            return torch.dropout(x, p=0.5, train=False)
1816
1817        self.assertEqual(foo(input), input)
1818
1819    def test_trace_contiguous(self):
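        # contiguous() on a non-contiguous slice has to copy, so the traced output
        # must not share storage with the input.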
1820        def foo(x):
1821            return x[:, :, ::2].contiguous().view(12)
1822
1823        x = torch.rand(2, 3, 4)
1824        traced = torch.jit.trace(foo, (x,))
1825        y = traced(x)
1826        self.assertNotEqual(x.storage().data_ptr(), y.storage().data_ptr())
1827
1828    # This tests the logic in THPVariable_contiguous. There is short-circuiting
1829    # code that prevents us from even getting to VariableType::contiguous, since
1830    # it is an optimization that prevents us from acquiring the GIL for touching
1831    # the device. We needed to add the tracing logic directly into the
1832    # THPVariable_contiguous function only for the path where we are skipping
1833    # dispatch into contiguous. We should see an aten::contiguous in this trace!
1834    def test_trace_contiguous_short_circuit(self):
1835        def foo(x):
1836            return x.contiguous()
1837
1838        x = torch.rand(2, 3, 4)
1839        traced = torch.jit.trace(foo, (x,))
1840        FileCheck().check("aten::contiguous").run(str(traced.graph))
1841
1842    def test_trace_inverse(self):
1843        def foo(x):
1844            return ~x
1845
1846        foo_traced = torch.jit.trace(foo, torch.zeros(3, 4, dtype=torch.uint8))
1847        eg = torch.zeros(3, dtype=torch.uint8)
1848        self.assertEqual(foo_traced(eg), foo(eg))
1849
1850    def test_trace_modulelist(self):
1851        class MySubmod(torch.nn.Module):
1852            def __init__(self) -> None:
1853                super().__init__()
1854                self.relu = torch.nn.ReLU()
1855
1856            def forward(self, x):
1857                return self.relu(x)
1858
1859        class MyMod(torch.nn.Module):
1860            def __init__(self) -> None:
1861                super().__init__()
1862                self.ml = torch.nn.ModuleList([MySubmod(), MySubmod()])
1863
1864            def forward(self, x):
1865                for mod in self.ml:
1866                    x = mod(x)
1867                return x
1868
1869        traced = torch.jit.trace(MyMod(), (torch.rand(3, 4),))
1870
1871    def test_trace_fork_join_and_module(self):
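        # Tracing should handle torch.jit._fork/_wait calls that dispatch to
        # submodules held in a ModuleList.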
1872        class MySubmod(torch.nn.Module):
1873            def __init__(self) -> None:
1874                super().__init__()
1875                self.relu = torch.nn.ReLU()
1876
1877            def forward(self, x):
1878                return self.relu(x), torch.neg(x)
1879
1880        class Mod(torch.nn.Module):
1881            def __init__(self) -> None:
1882                super().__init__()
1883                self.ml = torch.nn.ModuleList([MySubmod() for i in range(2)])
1884
1885            def forward(self, x):
1886                futs = []
1887                for i in range(2):
1888                    futs.append(torch.jit._fork(self.ml[i], x))
1889
1890                results = []
1891                for i in range(2):
1892                    results.append(torch.jit._wait(futs[i])[0])
1893
1894                return torch.stack(results)
1895
1896        m = Mod()
1897        traced = torch.jit.trace(m, torch.rand(3, 4))
1898
1899    def test_trace_invert_module_hierarchy(self):
1900        class MySubmod(torch.nn.Module):
1901            def __init__(self) -> None:
1902                super().__init__()
1903                self.relu = torch.nn.ReLU()
1904
1905            def forward(self, x):
1906                return self.relu(x), torch.neg(x)
1907
1908        class MyFunctionalMod(torch.nn.Module):
1909            def forward(self, x, submod):
1910                return submod(x)
1911
1912        class Mod(torch.nn.Module):
1913            def __init__(self) -> None:
1914                super().__init__()
1915                self.sm = MySubmod()
1916                self.fm = MyFunctionalMod()
1917
1918            def forward(self, x):
1919                return self.fm(x, self.sm)
1920
1921        torch.jit.trace(Mod(), (torch.rand(3, 4),))
1922
1923    @skipIfCrossRef
1924    def test_trace_records_names(self):
1925        def foo(bar, baz):
1926            baz = bar + 3
1927            quick_brown_fox = torch.neg(baz)
1928            for _ in range(20):
1929                yeet = quick_brown_fox - 3.14
1930            return yeet
1931
1932        traced = torch.jit.trace(foo, (torch.rand(3, 3), torch.rand(3, 3)))
1933        graph_str = str(traced.graph)
1934        assert "bar" in graph_str
1935        assert "baz" in graph_str
1936        assert "quick_brown_fox" in graph_str
1937
1938    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
1939    def test_tracing_hooks(self):
1940        class Net(nn.Module):
1941            def forward(self, x):
1942                return x + x
1943
1944        def test_hook(is_post_hook, hook, fc):
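            # Helper: register `hook` as a forward or forward-pre hook on a fresh Net,
            # trace the module, run the expected FileCheck pattern over the traced
            # graph, and verify that the traced output matches eager execution.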
1945            n = Net()
1946            if is_post_hook:
1947                n.register_forward_hook(hook)
1948            else:
1949                n.register_forward_pre_hook(hook)
1950
1951            module = torch.jit.trace(n, (torch.tensor(1.0),))
1952
1953            eager_input = torch.tensor(1.0)
1954            eager_out = n(eager_input)
1955
1956            fc.run(module.forward.graph)
1957            input = torch.tensor(1.0)
1958            output = module(input)
1959
1960            self.assertEqual(input, eager_input)
1961            self.assertEqual(output, eager_out)
1962
1963        def hook_no_return(mod, input, output):
1964            input[0].add_(1)
1965            output.sub_(1)
1966
1967        fc = FileCheck().check("add(").check("add_(").check("sub_(")
1968        test_hook(True, hook_no_return, fc)
1969
1970        def hook_return(mod, input, output):
1971            input[0].add_(1)
1972            return output - 3
1973
1974        fc = FileCheck().check("add(").check("add_(").check("sub(")
1975        test_hook(True, hook_return, fc)
1976
1977        b = torch.tensor(3.0)
1978
1979        def captured_hook(mod, input, output):
1980            return output - b
1981
1982        fc = FileCheck().check("add(").check("sub(")
1983        test_hook(True, captured_hook, fc)
1984
1985        def pre_hook_no_ret(mod, input):
1986            input[0].add_(3)
1987
1988        fc = FileCheck().check("add_(").check("add(")
1989        test_hook(False, pre_hook_no_ret, fc)
1990
1991        def pre_hook_ret(mod, input):
1992            return input[0] - 4
1993
1994        fc = FileCheck().check("sub(").check("add(")
1995        test_hook(False, pre_hook_ret, fc)
1996
1997    def test_tracing_backward_hook_error(self):
1998        class Net(nn.Module):
1999            def forward(self, x):
2000                return x + x
2001
2002        n = Net()
2003
2004        def backward_hook(module, grad_input, grad_output):
2005            pass
2006
2007        n.register_backward_hook(backward_hook)
2008        with self.assertRaisesRegex(Exception, "backward hooks assigned"):
2009            torch.jit.trace(n, (torch.tensor(1.0),))
2010
2011    def test_tracing_multiple_methods(self):
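        # torch.jit.trace_module takes a dict mapping method names to example
        # inputs, so weighted_kernel_sum is traced alongside forward.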
2012        class Net(nn.Module):
2013            def __init__(self) -> None:
2014                super().__init__()
2015                self.conv = nn.Conv2d(1, 1, 3)
2016
2017            def forward(self, x):
2018                return self.conv(x)
2019
2020            def weighted_kernel_sum(self, weight):
2021                return weight * self.conv.weight
2022
2023        example_weight = torch.rand(1, 1, 3, 3)
2024        example_forward_input = torch.rand(1, 1, 3, 3)
2025        inputs = {
2026            "forward": example_forward_input,
2027            "weighted_kernel_sum": example_weight,
2028        }
2029        n = Net()
2030        module = torch.jit.trace_module(n, inputs)
2031
2032        check_inputs = []
2033        for i in range(2):
2034            check_weight = torch.rand(1, 1, 3, 3)
2035            check_forward_input = torch.rand(1, 1, 3, 3)
2036            check_inputs.append(
2037                {"forward": check_forward_input, "weighted_kernel_sum": check_weight}
2038            )
2039        module = torch.jit.trace_module(
2040            n, inputs, check_trace=True, check_inputs=check_inputs
2041        )
2042        self.assertTrue(module._c._has_method("forward"))
2043        self.assertTrue(module._c._has_method("weighted_kernel_sum"))
2044
2045        module = torch.jit.trace(n.forward, example_forward_input)
2046        module = torch.jit.trace(
2047            n.forward,
2048            example_forward_input,
2049            check_trace=True,
2050            check_inputs=[example_forward_input],
2051        )
2052        with self.assertRaisesRegex(
2053            AttributeError,
2054            "trace doesn't support compiling individual module's functions",
2055        ):
2056            module = torch.jit.trace(n.weighted_kernel_sum, inputs)
2057
2058    def test_tensor_with_grad_as_constant(self):
2059        param = torch.randn(3).requires_grad_()
2060        x = torch.randn(3)
2061
2062        def f(x):
2063            return x + param
2064
2065        with self.assertRaisesRegex(
2066            RuntimeError, "Cannot insert a Tensor that requires grad as a constant"
2067        ):
2068            torch.jit.trace(f, x)
2069
2070    def test_non_tensor_tracing(self):
2071        def f(x):
2072            return x + param  # noqa: F821
2073
2074        with self.assertRaisesRegex(
2075            RuntimeError, r"Type 'Tuple\[int\]' cannot be traced"
2076        ):
2077            torch.jit.trace(f, (1,))
2078
2079    def test_trace_skip_none_submodule(self):
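        # A submodule attribute set to None should be skipped by the tracer rather
        # than showing up on the traced module.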
2080        class TestModule(torch.nn.Module):
2081            def __init__(self) -> None:
2082                super().__init__()
2083                self.submod = torch.nn.Linear(3, 4)
2084                self.submod = None
2085
2086            def forward(self, inputs):
2087                return inputs
2088
2089        m = TestModule()
2090        tm = torch.jit.trace(m, torch.tensor(1.0))
2091        self.assertFalse(hasattr(tm, "submod"))
2092
2093    def test_trace_with_conditional_property(self):
2094        class Net(nn.Module):
2095            def __init__(self, attr=None):
2096                super().__init__()
2097                if attr is not None:
2098                    self._attr = attr
2099                self.attr_name = "_attr"
2100
2101            @property
2102            def attr(self):
2103                return getattr(self, self.attr_name)
2104
2105            def forward(self, x):
2106                return x
2107
2108        x = torch.ones(1)
2109        torch.jit.trace(Net(), x)
2110
2111    def test_trace_func_argument_names_captured(self):
2112        def fn(first_arg: torch.Tensor, second_arg: torch.Tensor) -> torch.Tensor:
2113            return first_arg + second_arg
2114
2115        traced_fn = torch.jit.trace(fn, (torch.ones(1), torch.ones(1)))
2116        FileCheck().check("first_arg").check_next("second_arg").run(
2117            str(traced_fn.graph)
2118        )
2119
2120    def test_trace_partial_func_argument_names_captured(self):
2121        def fn(first_arg: torch.Tensor, second_arg=1) -> torch.Tensor:
2122            return first_arg + second_arg
2123
2124        traced_fn = torch.jit.trace(fn, (torch.ones(1),))
2125        FileCheck().check("first_arg").check_not("second_arg").run(str(traced_fn.graph))
2126
2127    def test_trace_module_argument_names_captured(self):
2128        class TestModule(nn.Module):
2129            def __init__(self) -> None:
2130                super().__init__()
2131                self.conv = nn.Conv2d(1, 1, 3)
2132
2133            def forward(self, first_arg: torch.Tensor, second_arg: torch.Tensor):
2134                return self.conv(first_arg) + second_arg
2135
2136        m = TestModule()
2137        example_input = (torch.ones(1, 1, 3, 3), torch.ones(1, 1, 3, 3))
2138
2139        # Explicitly tracing the module's forward method
2140        traced_module_forward = torch.jit.trace(m.forward, example_input)
2141        FileCheck().check("first_arg").check_next("second_arg").run(
2142            str(traced_module_forward.graph)
2143        )
2144
2145        # Tracing the module directly
2146        traced_module = torch.jit.trace(m, example_input)
2147        FileCheck().check("first_arg").check_next("second_arg").run(
2148            str(traced_module.graph)
2149        )
2150
2151    def test_trace_checking_with_deprecated_name(self):
2152        class MyClass(torch.nn.Module):
2153            def __init__(self) -> None:
2154                super(MyClass, self).__init__()
2155
2156            def forward(self, x, y, **deprecated_arguments):
2157                if len(deprecated_arguments) > 0:
2158                    raise RuntimeError(
2159                        f"Got unexpected arguments: {deprecated_arguments}"
2160                    )
2161                return x + y
2162
2163        model = MyClass()
2164        m2 = torch.jit.trace(model, (torch.ones(1), torch.ones(1)))
2165        m3 = torch.jit.trace(
2166            model,
2167            example_kwarg_inputs={"x": torch.ones(1), "y": torch.ones(1)},
2168            strict=False,
2169        )
2170
2171    def test_trace_with_tuple_tensor(self):
2172        class MyClass(torch.nn.Module):
2173            def __init__(self) -> None:
2174                super(MyClass, self).__init__()
2175
2176            def forward(self, x, y):
2177                return x + y[0] + y[1]
2178
2179        model = MyClass()
2180        traced_model = torch.jit.trace(
2181            model, (torch.ones(1), (torch.ones(1), torch.ones(1)))
2182        )
2183        input_dict = {
2184            "x": torch.tensor([2, 3]),
2185            "y": (torch.tensor([5, 6]), torch.tensor([7, 8])),
2186        }
2187        self.assertEqual(model(**input_dict), traced_model(**input_dict))
2188        traced_model = torch.jit.trace(
2189            model,
2190            example_kwarg_inputs={
2191                "x": torch.ones(1),
2192                "y": (torch.ones(1), torch.ones(1)),
2193            },
2194        )
2195        self.assertEqual(model(**input_dict), traced_model(**input_dict))
2196
2197    def test_trace_no_duplicated_lifted_input_output(self):
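        # Tracing this nested structure should not duplicate lifted inputs/outputs;
        # the check below asserts that no prim::TupleUnpack is left in the graph.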
2198        class Normalize(nn.Module):
2199            def __init__(self) -> None:
2200                super().__init__()
2201                self.norm = nn.GroupNorm(num_groups=32, num_channels=32)
2202
2203            def forward(self, x, y):
2204                if y is None:
2205                    y = x
2206                else:
2207                    y = self.norm(y)
2208                y = y * 2
2209                return y
2210
2211        class G(nn.Module):
2212            def __init__(self) -> None:
2213                super().__init__()
2214                self.norm = Normalize()
2215
2216            def forward(self, x):
2217                A = self.norm(x, None)
2218                B = F.relu(A)
2219                return A, B
2220
2221        class Net(nn.Module):
2222            def __init__(self) -> None:
2223                super().__init__()
2224                self.g = G()
2225                self.norm_1 = Normalize()
2226
2227            def forward(self, x):
2228                hs = self.g(x)
2229                A, B = hs
2230                h = self.norm_1(B, A)
2231                return h
2232
2233        net = Net()
2234        net = net.eval()
2235        x = torch.randn(1, 32, 16, 16)
2236        traced = torch.jit.trace(net, x)
2237        FileCheck().check_not("prim::TupleUnpack").run(str(traced.graph))
2238
2239
2240@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
2241class TestMixTracingScripting(JitTestCase):
2242    def test_trace_script(self):
2243        @torch.jit.script
2244        def func1(x: Tuple[Tensor, Tensor]) -> Tensor:
2245            return x[0] + x[1]
2246
2247        @torch.jit.script
2248        def func2(x: List[Tensor]) -> Tensor:
2249            return x[0] + x[1]
2250
2251        a = torch.randn(5)
2252        b = torch.randn(5)
2253
2254        self.checkTrace(func1, ((a, b),))
2255        self.checkTrace(func2, ((a, b),))
2256
2257        @torch.jit.script
2258        def func3(
2259            x: Tensor, method: str = "bilinear", align_corners: bool = True
2260        ) -> Tensor:
2261            hw = x.shape[2:4]
2262            return F.interpolate(x, hw, mode=method, align_corners=align_corners)
2263
2264        inp = torch.rand(1, 3, 6, 6)
2265        self.checkTrace(func3, (inp,))
2266
2267        @torch.jit.script
2268        def func4(x: Tensor, a: List[Optional[str]]) -> Tensor:
2269            if len(a) == 2:
2270                return x + 2
2271            else:
2272                return x
2273
2274    def test_trace_mixed_by_script_with_dict_output(self):
2275        @torch.jit.script
2276        def return_dict(input: torch.Tensor) -> Dict[str, torch.Tensor]:
2277            return {"foo": input + 1}
2278
2279        class TraceModule(torch.nn.Module):
2280            def forward(self, input):
2281                dict = return_dict(input)
2282                return dict["foo"] + dict["foo"]
2283
2284        x = torch.ones(1)
2285        tm = torch.jit.trace(TraceModule(), x)
2286        self.assertEqual(tm(x), x + 1 + x + 1)
2287
2288    def test_trace_of_script(self):
2289        @torch.jit.script
2290        def foo(a, c):
2291            b = 0.0
2292            if bool(a == 0.0):
2293                b = 1.0
2294            return b + c
2295
2296        a = torch.ones(1, dtype=torch.float)
2297
2298        @_trace(torch.zeros(1, dtype=torch.float))
2299        def use(b):
2300            return foo(b - 1.0, a) + 1.0
2301
2302        # test that we propagated shapes through the function
2303        self.assertTrue("Dynamic" not in str(use.graph))
2304
2305        self.assertEqual(3, use(torch.ones(1, dtype=torch.float)))
2306        self.assertEqual(2, use(torch.zeros(1, dtype=torch.float)))
2307
2308    def test_trace_with_size(self):
2309        @_trace(torch.zeros(1, 1))
2310        def foo(x):
2311            return x + 1
2312
2313        @torch.jit.script
2314        def bar(x):
2315            y = int(foo(x))
2316            if 1 == 1:
2317                y = 7
2318            return y + 1
2319
2320        self.assertEqual(8, bar(torch.ones(1, 1)))
2321
2322    def test_tracing_slicing(self):
2323        @_trace(torch.zeros(10))
2324        def foo_trace(x):
2325            return x[-5:-3]
2326
2327        @torch.jit.script
2328        def foo_script(x):
2329            return x[-5:-3]
2330
2331        def foo(x):
2332            return x[-5:-3]
2333
2334        a = torch.arange(0, 8)
2335        b = torch.arange(0, 20)
2336        self.assertEqual(foo_trace(a), foo_script(a))
2337        self.assertEqual(foo_trace(a), foo(a))
2338        self.assertNotEqual(foo_trace(a), foo_trace(b))
2339
2340    def test_tracing_indexing(self):
2341        @_trace(torch.zeros(10))
2342        def foo_trace(x):
2343            return x[-2]
2344
2345        @torch.jit.script
2346        def foo_script(x):
2347            return x[-2]
2348
2349        def foo(x):
2350            return x[-2]
2351
2352        a = torch.arange(0, 8)
2353        b = torch.arange(0, 20)
2354        self.assertEqual(foo_script(a), foo_trace(a))
2355        self.assertEqual(foo_trace(a), foo(a))
2356        self.assertNotEqual(foo_trace(a), foo_trace(b))
2357
2358    def test_trace_hierarchy(self):
2359        # Test that we preserve the module hierarchy for a ScriptModule
2360        # submodule during tracing
2361
2362        class AnotherScriptMod(torch.jit.ScriptModule):
2363            def __init__(self) -> None:
2364                super().__init__()
2365                self.param = torch.nn.Parameter(torch.rand(1, 2, 3))
2366
2367            @torch.jit.script_method
2368            def bar(self):
2369                return torch.zeros(4, 5)
2370
2371        class SomeScriptMod(torch.jit.ScriptModule):
2372            def __init__(self) -> None:
2373                super().__init__()
2374                self.asm = AnotherScriptMod()
2375
2376            @torch.jit.script_method
2377            def foo(self):
2378                return torch.zeros(3, 4)
2379
2380            @torch.jit.script_method
2381            def bar(self):
2382                return torch.zeros(4, 3)
2383
2384        class TraceMe(torch.nn.Module):
2385            def __init__(self) -> None:
2386                super().__init__()
2387                self.ssm = SomeScriptMod()
2388
2389            def forward(self, x):
2390                return self.ssm.bar() + x
2391
2392        orig = TraceMe()
2393        traced = torch.jit.trace(orig, (torch.rand(4, 3),))
2394        # For each of these checks, verify that *BOTH* the underlying
2395        # _C.ScriptModule object and the Python object that wraps it
2396        # have the expected method/param.
2397        self.assertTrue(traced.ssm._c._has_method("foo"))
2398        self.assertTrue(hasattr(traced.ssm, "foo"))
2399
2400        imported = self.getExportImportCopy(traced)
2401
2402        self.assertTrue(imported.ssm._c._has_method("foo"))
2403        self.assertTrue(hasattr(imported.ssm, "foo"))
2404
2405        self.assertTrue(imported.ssm.asm._c._has_method("bar"))
2406        self.assertTrue(hasattr(imported.ssm.asm, "bar"))
2407
2408        self.assertTrue(hasattr(imported.ssm.asm, "param"))
2409
2410    def test_trace_parameter(self):
2411        class Param(nn.Module):
2412            def __init__(self) -> None:
2413                super().__init__()
2414                self.register_parameter("bias", nn.Parameter(torch.empty(4, 4)))
2415
2416            def forward(self, x):
2417                return x
2418
2419        class M3(torch.jit.ScriptModule):
2420            def __init__(self, model):
2421                super().__init__()
2422                self.traced = torch.jit.trace(model, (torch.rand(3, 3)))
2423
2424            @torch.jit.script_method
2425            def forward(self, x):
2426                return self.traced(x)
2427
2428        class M2(nn.Module):
2429            def __init__(self, model):
2430                super().__init__()
2431                self.module = M3(model)
2432
2433            def forward(self, x):
2434                return self.module(x)
2435
2436        class M1(torch.jit.ScriptModule):
2437            def __init__(self, model):
2438                super().__init__()
2439                self.traced = torch.jit.trace(M2(model), (torch.rand(3, 3)))
2440
2441            @torch.jit.script_method
2442            def forward(self, x):
2443                return self.traced(x)
2444
2445        with torch.jit.optimized_execution(False):
2446            module = M1(Param())
2447            f = io.BytesIO()
2448            torch.jit.save(module, f)
2449
2450    @_tmp_donotuse_dont_inline_everything
2451    def test_call_script_fn_from_traced_module(self):
2452        @torch.jit.script
2453        def scripted_fn(x):
2454            return torch.neg(x)
2455
2456        class TracedModule(torch.nn.Module):
2457            def __init__(self) -> None:
2458                super().__init__()
2459                self.param = torch.nn.Parameter(torch.rand(4, 5))
2460
2461            def forward(self, x):
2462                return scripted_fn(torch.mm(x, self.param))
2463
2464        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
2465        FileCheck().check("aten::mm").check('name="scripted_fn"').check(
2466            "prim::CallFunction"
2467        ).run(str(tm.graph))
2468
2469    @_tmp_donotuse_dont_inline_everything
2470    def test_call_script_module_from_traced_module(self):
2471        class ScriptMod(torch.jit.ScriptModule):
2472            def __init__(self) -> None:
2473                super().__init__()
2474                self.param_foo = torch.nn.Parameter(torch.rand(5, 7))
2475
2476            @torch.jit.script_method
2477            def forward(self, x):
2478                return torch.mm(x, self.param_foo)
2479
2480        class TracedModule(torch.nn.Module):
2481            def __init__(self) -> None:
2482                super().__init__()
2483                self.param = torch.nn.Parameter(torch.rand(4, 5))
2484                self.mod = ScriptMod()
2485
2486            def forward(self, x):
2487                return self.mod(torch.mm(x, self.param)) + 1.0
2488
2489        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
2490
2491        FileCheck().check("aten::mm").check("prim::CallMethod").check_same(
2492            "forward"
2493        ).check("aten::add").run(str(tm.graph))
2494
2495    @_tmp_donotuse_dont_inline_everything
2496    def test_call_traced_fn_from_script_fn(self):
2497        @_trace(torch.rand(3, 4))
2498        def traced_fn(x):
2499            return torch.neg(x)
2500
2501        @torch.jit.script
2502        def script_fn(x):
2503            return traced_fn(x) + 1
2504
2505        FileCheck().check("prim::CallFunction").check("aten::add").run(
2506            str(script_fn.graph)
2507        )
2508
2509    def test_call_traced_mod_from_script_fn(self):
2510        with self.assertRaisesRegex(
2511            RuntimeError,
2512            "Cannot call a ScriptModule that is not a submodule of the caller",
2513        ):
2514
2515            class TracedModule(torch.nn.Module):
2516                def forward(self, x):
2517                    return torch.mm(x, torch.zeros(4, 3))
2518
2519            tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
2520
2521            @torch.jit.script
2522            def script_fn(x):
2523                return tm(x) + 1
2524
2525    @_tmp_donotuse_dont_inline_everything
2526    def test_call_tracing_fn_from_script_module(self):
2527        @_trace(torch.rand(3, 3))
2528        def traced_fn(x):
2529            return torch.neg(x)
2530
2531        class ScriptMod(torch.jit.ScriptModule):
2532            def __init__(self) -> None:
2533                super().__init__()
2534                self.param = torch.nn.Parameter(torch.rand(4, 3))
2535
2536            @torch.jit.script_method
2537            def forward(self, x):
2538                return traced_fn(torch.mm(x, self.param))
2539
2540        sm = ScriptMod()
2541        FileCheck().check("aten::mm").check("prim::CallFunction").run(
2542            str(sm.forward.graph)
2543        )
2544
2545    @_tmp_donotuse_dont_inline_everything
2546    def test_call_tracing_mod_from_script_module(self):
2547        class TracedMod(torch.nn.Module):
2548            def __init__(self) -> None:
2549                super().__init__()
2550                self.param = torch.nn.Parameter(torch.rand(3, 5))
2551
2552            def forward(self, x):
2553                return torch.mm(x, self.param)
2554
2555        class ScriptMod(torch.jit.ScriptModule):
2556            def __init__(self) -> None:
2557                super().__init__()
2558                self.param = torch.nn.Parameter(torch.rand(4, 3))
2559                self.tm = torch.jit.trace(TracedMod(), torch.rand(3, 3))
2560
2561            @torch.jit.script_method
2562            def forward(self, x):
2563                return self.tm(torch.mm(x, self.param))
2564
2565        sm = ScriptMod()
2566        FileCheck().check("aten::mm").check("prim::CallMethod").run(str(sm.graph))
2567
2568    def test_script_inline_trace_multiple_args(self):
2569        class M(torch.nn.Module):
2570            def forward(self, input, input2):
2571                return input + input2
2572
2573        class M2(torch.jit.ScriptModule):
2574            def __init__(self) -> None:
2575                super().__init__()
2576                self.m = torch.jit.trace(M(), (torch.zeros(4, 3), torch.zeros(4, 3)))
2577
2578            @torch.jit.script_method
2579            def forward(self, inp):
2580                return self.m(inp, inp)
2581
2582        with torch.jit.optimized_execution(False):
2583            m2 = M2()
2584            m2(torch.zeros(4, 3))
2585
2586    def test_trace_dict_mix_script(self):
2587        class testB(torch.nn.Module):
2588            def __init__(self) -> None:
2589                super().__init__()
2590                self.linear = torch.nn.Linear(2, 2)
2591
2592            def forward(self, feature_map: Dict[str, List[Tensor]]) -> Tensor:
2593                output = []
2594                for j in feature_map.values():
2595                    output.append(self.linear(j[0]))
2596
2597                return torch.stack(output)
2598
2599        class testA(torch.nn.Module):
2600            def __init__(self) -> None:
2601                super().__init__()
2602                self.b = torch.jit.script(testB())
2603
2604            def forward(self, input_map: Dict[str, List[Tensor]]) -> Tensor:
2605                feature_map = {}
2606                for i, j in input_map.items():
2607                    feature_map[i] = [j[0]]
2608
2609                return self.b(feature_map)
2610
2611        input_map = {
2612            "1": [torch.rand(2, 2), torch.rand(2, 2)],
2613            "3": [torch.rand(2, 2), torch.rand(2, 2)],
2614        }
2615        model = testA()
2616        traced_model = torch.jit.trace(model, input_map)
2617        new_input_map = {
2618            "1": [torch.rand(2, 2), torch.randn(2, 2)],
2619            "3": [torch.rand(2, 2), torch.rand(2, 2)],
2620        }
2621        self.assertEqual(model(new_input_map), traced_model(new_input_map))
2622
2623    def test_trace_script_returning_complex_dict(self):
2624        """Tracing over a script function returning a dictionary should work.
2625        The dictionary can should be able to contain other containers (like a tuple) recursively.
2626        The dictionary should be able to contain other containers (like a tuple) recursively.
2627
2628        class ReturnsDict(torch.nn.Module):
2629            def forward(
2630                self,
2631                id_score_list: Dict[
2632                    str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
2633                ],
2634            ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
2635                # do some random operations and then return a dict of the same structure
2636                v = id_score_list["1000"]
2637                idx_keys = v[1] - 1500000
2638                weights = v[2]
2639                result = {"1000": (v[0], idx_keys, weights)}
2640                return result
2641
2642        class ChecksDict(torch.nn.Module):
2643            def forward(
2644                self, input: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
2645            ):
2646                v = input["1000"]
2647                return v[1] + 1
2648
2649        class TestModule(torch.nn.Module):
2650            def __init__(self, checks_dict, returns_dict):
2651                super().__init__()
2652                self.checks_dict = checks_dict
2653                self.returns_dict = returns_dict
2654
2655            def forward(
2656                self, input: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
2657            ):
2658                foo = self.returns_dict(input)
2659                return self.checks_dict(foo)
2660
2661        input1 = {
2662            "1000": (
2663                torch.tensor([0]),
2664                torch.tensor([], dtype=torch.int64),
2665                torch.tensor([]),
2666            )
2667        }
2668
2669        input2 = {
2670            "1000": (
2671                torch.tensor([0]),
2672                torch.tensor([1500000, 1500004], dtype=torch.int64),
2673                torch.tensor([2.0, 3.0]),
2674            )
2675        }
2676
2677        checks_dict = torch.jit.script(ChecksDict())
2678        returns_dict = torch.jit.script(ReturnsDict())
2679        eager_module = TestModule(checks_dict, returns_dict)
2680        traced_module = torch.jit.trace(eager_module, input1)
2681        self.assertEqual(traced_module(input1), eager_module(input1))
2682        self.assertEqual(traced_module(input2), eager_module(input2))
2683
2684    def test_trace_returning_dict_with_tensor_tuples(self):
2685        """Tracing over a module returning a dictionary whose values are tuples of tensors
2686        should work.
2687        """
2688
2689        class ReturnsDict(torch.nn.Module):
2690            def forward(
2691                self, k: torch.Tensor, v: torch.Tensor
2692            ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
2693                x = 2 * k
2694                y = 3 * v
2695                result = {"imakey": (x, y)}
2696                return result
2697
2698        class ReturnsBadDict(torch.nn.Module):
2699            def forward(
2700                self, k: torch.Tensor, v: torch.Tensor
2701            ) -> Dict[str, Tuple[torch.Tensor, float]]:
2702                x = 2 * k
2703                result = {"imakey": (x, 1)}
2704                return result
2705
2706        mod = ReturnsDict()
2707        traced_module = torch.jit.trace(
2708            mod, [torch.ones(1), torch.ones(1)], strict=False
2709        )
2710        out = traced_module(torch.ones(1), torch.ones(1))
2711        expected = {"imakey": (torch.tensor([2.0]), torch.tensor([3.0]))}
2712        self.assertEqual(out, expected)
2713
2714        with self.assertRaisesRegex(
2715            RuntimeError, "cannot be understood by the tracer, only outputs matching"
2716        ):
2717            mod = ReturnsBadDict()
2718            traced_module = torch.jit.trace(
2719                mod, [torch.ones(1), torch.ones(1)], strict=False
2720            )
2721
2722    def test_trace_linear(self):
2723        m = torch.nn.Linear(20, 20)
2724        inp = torch.rand([20, 20])
2725        self.checkTrace(m, (inp,))
2726        g = torch.jit.trace(m, (inp,)).graph
2727        FileCheck().check("aten::linear").run(g)
2728
2729    def test_traced_module_implements_interface(self):
2730        @torch.jit.interface
2731        class TestModuleInterface(nn.Module):
2732            def forward(
2733                self, first_arg: torch.Tensor, second_arg: torch.Tensor
2734            ) -> torch.Tensor:
2735                pass
2736
2737        make_global(TestModuleInterface)
2738
2739        class TestModule(nn.Module):
2740            def __init__(self) -> None:
2741                super().__init__()
2742                self.conv = nn.Conv2d(1, 1, 3)
2743
2744            def forward(
2745                self, first_arg: torch.Tensor, second_arg: torch.Tensor
2746            ) -> torch.Tensor:
2747                return self.conv(first_arg) + second_arg
2748
2749        def fn_takes_interface(x: TestModuleInterface):
2750            ones = torch.ones(1, 1, 3, 3)
2751            return x.forward(ones, ones)
2752
2753        scripted_test_module = torch.jit.script(TestModule())
2754        self.checkScript(fn_takes_interface, (scripted_test_module,))
2755
2756    def test_traced_module_contains_scripted_interface_types(self):
2757        class LeafModule(torch.nn.Module):
2758            def __init__(self) -> None:
2759                super().__init__()
2760                self.weight = torch.nn.Parameter(torch.rand(19))
2761
2762            def forward(self, input: torch.Tensor):
2763                return input + self.weight
2764
2765        class LowerModuleImpl(torch.nn.Module):
2766            def __init__(self) -> None:
2767                super().__init__()
2768                self.leaf = LeafModule()
2769
2770            def forward(self, input: torch.Tensor) -> torch.Tensor:
2771                return self.leaf(input)
2772
2773        @torch.jit.interface
2774        class LowerModuleInterface(torch.nn.Module):
2775            def forward(self, input: torch.Tensor) -> torch.Tensor:
2776                pass
2777
2778        class MiddleModule(torch.nn.Module):
2779            lower: LowerModuleInterface
2780
2781            def __init__(self, feature_processor_modules=None):
2782                super().__init__()
2783                self.lower = LowerModuleImpl()
2784
2785            def forward(self, input):
2786                return self.lower(input)
2787
2788        class WrapperModule(torch.nn.Module):
2789            def __init__(self, m):
2790                super().__init__()
2791                self.middle = m
2792
2793            def forward(self, input):
2794                return self.middle(input)
2795
2796        class TopModule(torch.nn.Module):
2797            def __init__(self) -> None:
2798                super().__init__()
2799                m = MiddleModule()
2800                m = torch.jit.script(m)
2801                self.sub1 = m
2802                self.sub2 = WrapperModule(m)
2803
2804            def forward(self, input: torch.Tensor):
2805                return self.sub1(input) + self.sub2(input)
2806
2807        top = TopModule()
2808        top_example_input = torch.ones(1)
2809        torch.jit.trace(top, top_example_input)
2810
2811    def test_jit_trace_callfunction_return_shapes(self):
2812        # a torch.jit.script function gets inserted as a CallFunction node
2813        @torch.jit.script
2814        def inner_fn(x):
2815            return torch.cat((x, x))
2816
2817        def outer_fn(x, y):
2818            return inner_fn(x + y).relu()
2819
2820        x, y = [torch.rand((2, 2), dtype=torch.float) for _ in range(2)]
2821        fn_t = torch.jit.trace(outer_fn, (x, y))
2822
2823        # expect that the CallFunction node return type has shape information on it.
2824        FileCheck().check("Float").check("4, 2").check("CallFunction").run(fn_t.graph)
2825        for n in fn_t.graph.nodes():
2826            if n.kind() == "prim::CallFunction":
2827                self.assertTrue(n.output().isCompleteTensor())
2828