xref: /aosp_15_r20/external/pytorch/test/test_fx.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: fx"]
2
3import builtins
4import contextlib
5import copy
6import functools
7import inspect
8import math
9import numbers
10import io
11import operator
12import os
13import pickle
14import sys
15import torch
16import traceback
17import typing
18import types
19import warnings
20import unittest
21from math import sqrt
22from functorch.experimental import control_flow
23from torch.multiprocessing import Process
24from torch.testing import FileCheck
25from torch.testing._internal.common_methods_invocations import op_db
26from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests
27import torch.utils._pytree as pytree
28import torch.fx._pytree as fx_pytree
29from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap, PH, CodeGen
30from torch.fx.node import Target, Argument, _format_arg
31from torch.fx.passes import shape_prop
32from torch.fx.immutable_collections import immutable_dict, immutable_list
33from torch.fx.experimental.rewriter import RewritingTracer
34from torch.fx.operator_schemas import get_signature_for_torch_op
35from copy import deepcopy
36from collections import namedtuple
37
38from torch.fx.proxy import TraceError
39from torch.fx._compatibility import _BACK_COMPAT_OBJECTS, _MARKED_WITH_COMPATIBILITY
40from torch.fx._symbolic_trace import PHBase, PHWithMeta
41from fx.test_subgraph_rewriter import TestSubgraphRewriter  # noqa: F401
42from fx.test_dce_pass import TestDCE  # noqa: F401
43from fx.test_fx_const_fold import TestConstFold  # noqa: F401
44from fx.test_fx_param_shape_control_flow import TestConstParamShapeInControlFlow  # noqa: F401
45from fx.test_pass_infra import TestPassManager  # noqa: F401
46from fx.test_common_passes import TestCommonPass  # noqa: F401
47from fx.test_cse_pass import TestCSEPass  # noqa: F401
48from fx.test_matcher_utils import TestMatcher  # noqa: F401
49from fx.test_source_matcher_utils import TestSourceMatcher  # noqa: F401
50
51from fx.test_gradual_type import AnnotationsTest  # noqa: F401
52from fx.test_gradual_type import TypeCheckerTest  # noqa: F401
53from typing import Any, Callable, Dict, NamedTuple, List, Optional, Set, Tuple, Union
54from torch.testing._internal.common_utils import (
55    IS_FBCODE,
56    IS_MACOS,
57    IS_WINDOWS,
58    find_library_location,
59    run_tests,
60    skipIfTorchDynamo,
61)
62from torch.testing._internal.jit_utils import JitTestCase
63
64from fx.named_tup import MyNamedTup
65
66try:
67    from torchvision import models as torchvision_models
68    HAS_TORCHVISION = True
69except ImportError:
70    HAS_TORCHVISION = False
71skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
72from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
73
class SimpleTest(torch.nn.Module):
    """Tiny fixture module that computes ``relu(x + 3.0)``."""

    def forward(self, x):
        shifted = x + 3.0
        return torch.relu(shifted)
77
def a_non_torch_leaf(a, b):
    """Return ``a + b``; a plain (non-torch) callable used as a graph call target."""
    total = a + b
    return total
80
81# Used for test_autowrap_function. Autowrapped functions need to be global
def fx_int(x: float) -> int:
    """Convert ``x`` to ``int`` with the builtin conversion."""
    as_int = int(x)
    return as_int
84
def fx_int_x2(x: float) -> int:
    """Convert ``x`` to ``int`` and return twice that value."""
    whole = int(x)
    return whole * 2
87
# Used in test_pytree. Defined at module scope because pickling a GraphModule
# that uses Point errors out if Point is local to the test function.
Point = namedtuple('Point', ['x', 'y'])
91
92# Test wrap() passing both a function name as well as a function
93# directly
def a_lifted_leaf(a, b):
    """Add the first two entries of ``a`` to ``b``."""
    first, second = a[0], a[1]
    return first + second + b

# Registered with FX by name (string form of ``wrap``).
wrap('a_lifted_leaf')
# Registering the same name a second time must not break anything.
wrap('a_lifted_leaf')
100
def a_lifted_leaf2(a, b):
    """Add the first two entries of ``a`` to ``b``."""
    first, second = a[0], a[1]
    return first + second + b

# Registered with FX by passing the function object itself.
wrap(a_lifted_leaf2)
105
# Register the builtin names `len` and `getattr` with FX's `wrap`, by name.
wrap('len')

wrap('getattr')
109
def wrapped_named_tup(p1, *, p2):
    """Return ``p1.x + p2.y``; ``p2`` must be passed by keyword."""
    lhs = p1.x
    rhs = p2.y
    return lhs + rhs

wrap(wrapped_named_tup)
114
@wrap
def wrapped_via_decorator(a):
    """Return ``a + 1``; registered with FX by using ``wrap`` as a decorator."""
    incremented = a + 1
    return incremented
118
# Registered by name ahead of the definition below.
wrap('wrapped_with_submodule')

def wrapped_with_submodule(x: torch.Tensor, batchnorm1d: torch.nn.BatchNorm1d):
    """Apply the given ``batchnorm1d`` callable to ``x`` and return its result."""
    normalized = batchnorm1d(x)
    return normalized
123
def my_decorator(f):
    """Return a pass-through wrapper around ``f`` that keeps ``f``'s metadata."""
    def wrapper_inside_decorator(*args, **kwargs):
        # Forward the call untouched.
        return f(*args, **kwargs)
    return functools.wraps(f)(wrapper_inside_decorator)
129
# `wrap` is applied on top of `my_decorator`, i.e. to an already-decorated callable.
@wrap
@my_decorator
def wrapped_decorated_fn(x):
    """Identity function: return ``x`` unchanged."""
    return x
134
# Aliases bound at import time. Tests compare the module-level names against
# these with assertIs to check that tracing left the original function objects
# in place.
real_wrapped_via_decorator = wrapped_via_decorator
real_a_lifed_leaf = a_lifted_leaf
real_a_lifed_leaf2 = a_lifted_leaf2
# Alias of math.sqrt; not used within the visible part of this file.
_sqrt = sqrt
139
# Registered by name ahead of the definition below.
wrap('wrapper_fn')

def wrapper_fn(x):
    # NOTE(review): `torch.foo` does not look like a real torch API, so calling this
    # presumably fails on purpose — confirm against the test that uses it.
    return torch.foo(x)
144
class Pair(NamedTuple):
    """Two-field record exposing a custom representation hook for FX."""
    x : torch.Tensor
    y : torch.Tensor

    def _custom_fx_repr_fn(self) -> str:
        """Render both fields through ``_format_arg`` as ``Pair(x=..., y=...)``."""
        x_repr = _format_arg(self.x)
        y_repr = _format_arg(self.y)
        return f"Pair(x={x_repr}, y={y_repr})"
151
152# for testing pytrees
class Foo:  # noqa: B209
    """Plain container holding two attributes, ``a`` and ``b``."""

    def __init__(self, a, b):
        self.a, self.b = a, b
157
class Add(torch.nn.Module):
    """Module that adds its input to itself."""

    def forward(self, x):
        doubled = x + x
        return doubled
161
@torch.fx.has_side_effect
@torch.fx.wrap
def side_effect_func(x: torch.Tensor):
    """Print ``x`` and return nothing.

    Registered with both ``torch.fx.wrap`` and ``torch.fx.has_side_effect``.
    """
    print(x)
166
167class TestFX(JitTestCase):
168    def setUp(self):
169        super().setUp()
170        # Checking for mutable operations whil tracing is feature flagged
171        # Enable it in testing but not by default
172        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
173        torch.fx.proxy.TracerBase.check_mutable_operations = True
174
175        if not (IS_FBCODE or IS_WINDOWS or IS_MACOS):
176            lib_file_path = find_library_location('libtorchbind_test.so')
177            torch.ops.load_library(str(lib_file_path))
178
179    def tearDown(self):
180        super().tearDown()
181        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag
182
183    def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None):
184        """Check that an nn.Module's results match the GraphModule version
185        for a given set of args/kwargs.
186        """
187        kwargs = kwargs if kwargs else {}
188        ref_outs = m(*args, **kwargs)
189        gm = symbolic_trace(m)
190        gm.graph.lint()
191        test_outs = gm(*args, **kwargs)
192        self.assertEqual(ref_outs, test_outs)
193
194    def test_graph_module(self):
195        class MySub(torch.nn.Module):
196            def __init__(self) -> None:
197                super().__init__()
198                self.w = torch.nn.Parameter(torch.rand(4, 3))
199
200            def forward(self, x):
201                return self.w + x
202
203        class MyModule(torch.nn.Module):
204            def __init__(self) -> None:
205                super().__init__()
206                self.lin = torch.nn.Linear(4, 3)
207                self.sub_mod = MySub()
208                self.w = torch.nn.Parameter(torch.rand(3))
209
210            def forward(self, A, B, c):
211                t = torch.sigmoid(A) + self.lin(c)
212                return self.sub_mod(t.data + self.w + t + 1 - A + B // A + -A + A.add(B, alpha=3))
213
214        m = MyModule()
215        gm = symbolic_trace(m)
216
217        ms = torch.jit.script(gm)
218
219        class M2(torch.nn.Module):
220            def forward(self, A):
221                m, idx = torch.max(A, 0)
222                return m + 1, idx + 1
223
224        m2 = M2()
225        gm2 = symbolic_trace(m2)
226
227        class T(torch.nn.Module):
228
229            def forward(self, A, b=4, *args, c=5, **kwargs):
230                x = A + 1 + args[0] + kwargs['3']
231                return x
232
233        t = T()
234        symbolic_trace(t)
235
236        # test for issue described at https://github.com/pytorch/pytorch/issues/63883
237        class M3(torch.nn.Module):
238            def forward(self, x):
239                return torch.relu(x)
240
241        m3 = M3()
242        gm3 = symbolic_trace(m3)
243        new_instance = gm3.__new__(type(gm3))
244        new_instance.__init__(gm3, gm3.graph)
245
246        x = torch.randn(5, 3)
247        torch.testing.assert_close(new_instance(x), torch.relu(x))
248
249    def test_informative_co_filename(self):
250        class MyModule(torch.nn.Module):
251            def forward(self, a):
252                return a * 2
253
254        gm = symbolic_trace(MyModule())
255        self.assertIn(os.path.basename(__file__), gm.forward.__code__.co_filename)
256
257    def test_custom_import(self):
258        graph = torch.fx.Graph()
259        a = graph.placeholder('x')
260        b = graph.placeholder('y')
261        c = graph.call_function(a_non_torch_leaf, (a, b))
262        d = graph.call_function(torch.sin, (c,))
263        graph.output(d)
264        gm = GraphModule(torch.nn.Module(), graph)
265        x, y = torch.rand(1), torch.rand(1)
266        self.assertEqual(torch.sin(x + y), gm(x, y))
267
268    def test_args_kwargs(self):
269        class T(torch.nn.Module):
270            def forward(self, *args, **kwargs):
271                x = args[0] + kwargs['foo']
272                return x
273
274        t = T()
275        self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})
276
277    def test_varargs_concrete(self):
278        class T(torch.nn.Module):
279            def forward(self, *args, **kwargs):
280                x = args[0] + args[1]
281                return x
282
283        args = (torch.rand(1), torch.rand(1))
284
285        t = T()
286        ref_outs = t(*args)
287        gm = symbolic_trace(t, concrete_args=(torch.fx.PH, torch.fx.PH))
288        gm.graph.lint()
289        test_outs = gm(*args)
290        self.assertEqual(ref_outs, test_outs)
291
292    def test_args_kwargs_no_self(self):
293        class T(torch.nn.Module):
294            def forward(*args, **kwargs):  # noqa: B902
295                self = args[0]
296                return torch.relu(args[1])
297
298        t = T()
299        with self.assertRaisesRegex(RuntimeError, r'cannot be part of \*args expansion'):
300            self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})
301
302    def test_fx_shifts(self):
303        class MyModule(torch.nn.Module):
304            def forward(self, x):
305                return x << 3, x >> 3
306
307        input = torch.LongTensor(10).random_(0, 1024)
308
309        m = MyModule()
310        self.checkGraphModule(m, (input,))
311
312    def test_fx_and_or(self):
313        class MyModule(torch.nn.Module):
314            def forward(self, x):
315                return x & x, x | x
316
317        input = torch.LongTensor(10).random_(0, 1024)
318
319        m = MyModule()
320        self.checkGraphModule(m, (input,))
321
322    def test_dict(self):
323        class MyDictMod(torch.nn.Module):
324            def forward(self, d):
325                return d['3'].relu(), {'4' : d['3'].neg()}
326
327        input_dict = {'3': torch.rand(3, 4)}
328        m = MyDictMod()
329
330        self.checkGraphModule(m, (input_dict,))
331
332    def test_matmul_tracing(self):
333        const = torch.randn(3)
334
335        def matmul_f(x):
336            return x @ const
337
338        mod = symbolic_trace(matmul_f)
339        inp = torch.randn(3)
340        self.assertEqual(mod(inp), matmul_f(inp))
341
342        def rmatmul_f(x):
343            return const @ x
344
345        mod = symbolic_trace(rmatmul_f)
346        inp = torch.randn(3)
347        self.assertEqual(mod(inp), rmatmul_f(inp))
348
349    @skipIfNoDynamoSupport
350    def test_control_flow_tracing(self):
351        def true(x, y):
352            return x + y
353
354        def false(x, y):
355            return x - y
356
357        def f(x, y):
358            x = control_flow.cond(x[0] == 0, true, false, [x, y])
359
360        with self.assertRaisesRegex(RuntimeError, r"Expected pred to be bool or tensor, but got Proxy\(eq\)"):
361            _ = symbolic_trace(f)
362
363    def test_disallow_override(self):
364        # Custom delegate to disallow in-place tensor operations
365        class NoMutableCallTracer(Tracer):
366            def create_node(self, kind : str, target : Union[str, Callable],
367                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
368                            type_expr : Optional[Any] = None) -> Node:
369                name = target if isinstance(target, str) else torch.typename(target)
370                if name[-1] == '_':
371                    raise RuntimeError('In-place operations are not supported')
372                return super().create_node(kind, target, args, kwargs, name)
373
374        # Test method
375        class MyInplaceMod(torch.nn.Module):
376            def forward(self, x):
377                x.add_(3.0)
378                return x
379
380        m = MyInplaceMod()
381
382        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
383            NoMutableCallTracer().trace(m)
384
385        # Test free function
386        class MyInplaceMod2(torch.nn.Module):
387            def forward(self, x):
388                torch.log_(x)
389                return x
390        m2 = MyInplaceMod2()
391        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
392            NoMutableCallTracer().trace(m2)
393
394        # Test symbolic node as an arg
395        class MyInplaceMod3(torch.nn.Module):
396            def forward(self, x):
397                y = torch.ones(3, 4)
398                y.add_(x)
399                return x
400        m3 = MyInplaceMod3()
401        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
402            NoMutableCallTracer().trace(m3)
403
404    def test_leaf_module(self):
405        # Custom delegate to make it so that there are no leaf modules, everything
406        # should get traced through
407        class NoLeafModulesTracer(Tracer):
408            def is_leaf_module(self, m, qualname):
409                return False
410
411        class MyReluMod(torch.nn.Module):
412            def __init__(self) -> None:
413                super().__init__()
414                self.relu = torch.nn.ReLU()
415
416            def forward(self, x):
417                return self.relu(x)
418
419        mrm = MyReluMod()
420        sym = NoLeafModulesTracer().trace(mrm)
421        for node in sym.nodes:
422            self.assertNotEqual(node.op, 'call_module')
423        sym.lint()
424
425    def test_wrap(self):
426        self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5))
427
428        def to_trace(y):
429            return a_lifted_leaf((4, y), 3) + a_lifted_leaf((3, 4), 5) + a_lifted_leaf((y, y), y)
430
431        m = symbolic_trace(to_trace)
432        self.assertIn('a_lifted_leaf', m.code)
433        self.assertEqual(27, m(2))
434        self.assertIs(a_lifted_leaf, real_a_lifed_leaf)
435
436    def test_wrap_fn_directly(self):
437        self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))
438
439        def to_trace(y):
440            return a_lifted_leaf2((4, y), 3) + a_lifted_leaf2((3, 4), 5) + a_lifted_leaf2((y, y), y)
441
442        m = symbolic_trace(to_trace)
443        self.assertIn('a_lifted_leaf2', m.code)
444        self.assertEqual(27, m(2))
445        self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2)
446
447    def test_wrapped_via_decorator(self):
448        self.assertEqual(wrapped_via_decorator(0), 1)
449
450        def to_trace(y):
451            return wrapped_via_decorator(y)
452
453        m = symbolic_trace(to_trace)
454        self.assertIn('wrapped_via_decorator', m.code)
455        self.assertEqual(m(0), 1)
456        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
457        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
458
459    def test_wrapped_via_decorator_and_transformed(self):
460        self.assertEqual(wrapped_via_decorator(0), 1)
461
462        def to_trace(y):
463            return wrapped_via_decorator(y)
464
465        m = symbolic_trace(to_trace)
466        self.assertIn('wrapped_via_decorator', m.code)
467        self.assertEqual(m(0), 1)
468        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
469        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
470
471        transformed = torch.fx.Transformer(m).transform()
472        self.assertIn('wrapped_via_decorator', transformed.code)
473        self.assertEqual(transformed(0), 1)
474        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
475        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
476
477    def test_wrap_with_submodule(self):
478
479        class M(torch.nn.Module):
480            def __init__(self) -> None:
481                super().__init__()
482                self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
483
484            def forward(self, x: torch.Tensor):
485                return wrapped_with_submodule(x, self.batchnorm1d)
486
487        m = symbolic_trace(M())
488
489        self.assertIn("wrapped_with_submodule", m.code)
490
491        input = torch.rand(3, 2)
492        ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
493        self.assertEqual(ref_batchnorm1d(input), m(input))
494
495    def test_wrapped_retrace(self):
496        def to_trace(y):
497            return wrapped_via_decorator(y)
498
499        m = symbolic_trace(to_trace)
500        self.assertIn('wrapped_via_decorator', m.code)
501        self.assertEqual(m(0), 1)
502
503        retraced = symbolic_trace(m)
504        self.assertIn('wrapped_via_decorator', retraced.code)
505        self.assertEqual(retraced(0), 1)
506
507    def test_wrap_decorated_function(self):
508        def to_trace(y):
509            return wrapped_decorated_fn(y)
510
511        m = symbolic_trace(to_trace)
512        self.assertIn('wrapped_decorated_fn', m.code)
513        self.assertEqual(m(1), 1)
514
515    def test_graph_edit_with_proxy(self):
516        class M(torch.nn.Module):
517            def forward(self, a, b):
518                return a + b
519        m = M()
520        g = symbolic_trace(m).graph
521        new_g = torch.fx.Graph()
522        val_map : Dict[Node, Node] = {}
523        output_val = new_g.graph_copy(g, val_map)
524        t = Proxy(output_val)
525        # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
526        new_g.output((t + t).node)
527        gm = GraphModule(m, new_g)
528        gm.graph.lint()
529        self.assertEqual(gm(3, 4), 14)
530
531    def test_proxy_deepcopy_without_tracer(self):
532        class MyModule(torch.nn.Module):
533            def __init__(self):
534                super().__init__()
535
536            def forward(self, x):
537                return 2 * x
538
539        module = MyModule()
540        traced = symbolic_trace(module)
541        node = list(traced.graph.nodes)[-2]
542        p = torch.fx.Proxy(node, None)
543        node.proxy = p
544        p2 = copy.deepcopy(p)
545        self.assertTrue(isinstance(p2, torch.fx.Proxy))
546        self.assertEqual(p2.node.name, node.name)
547        self.assertEqual(p2.node.target, node.target)
548        self.assertNotEqual(id(p2.node), id(node))
549
550    def test_proxy_deepcopy_with_tracer(self):
551        class TestTracer(Tracer):
552            def __init__(self, name):
553                super().__init__()
554                self.name = name
555
556            def is_leaf_module(self, module, name):
557                return True
558
559        class MyModule(torch.nn.Module):
560            def __init__(self):
561                super().__init__()
562
563            def forward(self, x):
564                return 2 * x
565
566        module = MyModule()
567        tracer = TestTracer("mytracer")
568        traced = symbolic_trace(module)
569        node = list(traced.graph.nodes)[-2]
570        p = torch.fx.Proxy(node, tracer)
571        node.proxy = p
572        p2 = copy.deepcopy(p)
573        self.assertTrue(isinstance(p2, torch.fx.Proxy))
574        self.assertTrue(isinstance(p2.tracer, torch.fx._symbolic_trace.Tracer))
575        self.assertEqual(p2.tracer.name, "mytracer")
576        self.assertEqual(p2.node.name, node.name)
577        self.assertEqual(p2.node.target, node.target)
578        self.assertNotEqual(id(p2.node), id(node))
579        self.assertNotEqual(id(p2.tracer), id(tracer))
580
581    def test_concrete_arg_none_assert(self):
582        class Foo(torch.nn.Module):
583            def forward(self, x, val=None):
584                return x if val is None else x + val
585
586        f = Foo()
587        traced = torch.fx.symbolic_trace(f, concrete_args={'val' : None})
588        with self.assertRaisesRegex(AssertionError, 'val has been specialized to have value None'):
589            traced(torch.randn(5), torch.randn(5))
590
591        x = torch.randn(5)
592        torch.testing.assert_close(traced(x), f(x))
593
594    def test_trace_multiple_funcs(self):
595        class Foo(torch.nn.Module):
596            def forward(self, x, y):
597                return x + y
598
599            def minus_forward(self, x, y):
600                return x - y
601
602            def multiply_forward(self, x, y):
603                return x * y
604
605        f = Foo()
606        x, y = torch.randn(5), torch.randn(5)
607
608        print(torch.__version__)
609
610        tracer = Tracer()
611        torch.testing.assert_close(GraphModule(f, tracer.trace(f))(x, y), f(x, y))
612
613        tracer.traced_func_name = "minus_forward"
614        torch.testing.assert_close(
615            GraphModule(f, tracer.trace(f))(x, y),
616            f.minus_forward(x, y),
617        )
618
619        tracer.traced_func_name = "multiply_forward"
620        torch.testing.assert_close(
621            GraphModule(f, tracer.trace(f))(x, y),
622            f.multiply_forward(x, y),
623        )
624
625        tracer.traced_func_name = "add_forward"
626        with self.assertRaisesRegex(AssertionError, "doesn't exist in"):
627            tracer.trace(f)
628
629    def test_graph_unique_names(self):
630        class M(torch.nn.Module):
631            def forward(self, a, b):
632                return a + b
633        m = M()
634        g = symbolic_trace(m).graph
635        new_g = torch.fx.Graph()
636        val_map : Dict[Node, Node] = {}
637        output_val = new_g.graph_copy(g, val_map)
638        t = Proxy(output_val)
639        # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
640        new_g.output((t + t).node)
641        gm = GraphModule(m, new_g)
642        seen_names : Set[str] = set()
643        for node in gm.graph.nodes:
644            assert node.name not in seen_names
645            seen_names.add(node.name)
646
647    def test_stack_traces(self):
648        class M(torch.nn.Module):
649            def forward(self, a, b):
650                return a + b
651
652        tracer = torch.fx.Tracer()
653        tracer.record_stack_traces = True
654
655        graph = tracer.trace(M())
656        # saving the original list because we will insert new nodes as a part of a test
657        orig_graph_nodes = list(graph.nodes)
658        for node in orig_graph_nodes:
659            if node.op == 'output':
660                continue
661            self.assertTrue(node.stack_trace is not None)
662            assert 'test_fx.py' in node.stack_trace
663
664            # verify that copying the node does not lose the stack trace
665            new_node = graph.node_copy(node)
666            self.assertTrue(new_node.stack_trace is not None)
667            assert 'test_fx.py' in new_node.stack_trace
668
669    def test_stack_traces_with_transformer(self):
670        class M(torch.nn.Module):
671            def forward(self, a, b):
672                return a + b
673
674        tracer = torch.fx.Tracer()
675        tracer.record_stack_traces = True
676
677        graph = tracer.trace(M())
678        gm = GraphModule(tracer.root, graph)
679        new_gm = Transformer(gm).transform()
680
681        # nodes after Transformer should still preserve the original node's stack trace
682        for node in new_gm.graph.nodes:
683            if node.op in {'placeholder', 'output'}:
684                continue
685            self.assertTrue(node.stack_trace is not None)
686            assert 'test_fx.py' in node.stack_trace
687
688    def test_lineno_map(self):
689        class M(torch.nn.Module):
690            def forward(self, a, b):
691                a = torch.sin(a)
692                b = torch.cos(b)
693                return a + b
694
695        tracer = torch.fx.Tracer()
696        graph = tracer.trace(M())
697        gm = GraphModule(tracer.root, graph)
698        expected = {1: 2, 2: 3, 3: 4, 4: 5}
699        self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
700
701        # test custom codegen
702        def transform_code(code):
703            return ["print('hello!')\n", *code]
704        gm.graph.on_generate_code(lambda _: transform_code)
705        gm.recompile()
706        expected = {2: 2, 3: 3, 4: 4, 5: 5}
707        self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
708
709    def test_graph_unique_names_manual(self):
710        graph : torch.fx.Graph = torch.fx.Graph()
711        a : torch.fx.Node = graph.create_node('placeholder', 'x')
712        b : torch.fx.Node = graph.create_node('call_module', 'linear_mod', args=(a,), name='foo_1_1')
713        c : torch.fx.Node = graph.create_node('get_attr', 'y_attr', name='foo_1')
714        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
715        graph.output(d)
716        graph2 = torch.fx.Graph()
717        val_map : Dict[Node, Node] = {}
718        graph2.graph_copy(graph, val_map)
719        seen_names : Set[str] = set()
720        for node in graph2.nodes:
721            assert node.name not in seen_names
722            seen_names.add(node.name)
723
724    def test_unpack(self):
725        class M(torch.nn.Module):
726            def forward(self, a, b):
727                c, d = a
728                return c + d + b
729
730        a = (torch.rand(1), torch.rand(1))
731        b = torch.rand(1)
732        m = M()
733        self.checkGraphModule(m, (a, b))
734
735    def test_native_callable(self):
736        if IS_FBCODE or IS_WINDOWS or IS_MACOS:
737            raise unittest.SkipTest("non-portable load_library call used in test")
738        # This test exercises the case where we use FX to translate from Python
739        # code to some native callable object
740        #
741        # For the purposes of testing, we use ElementwiseInterpreter defined
742        # in test_custom_class.cpp.
743        #
744        # We test that we can
745        # 1) Construct a native callable from FX IR
746        # 2) Construct a drop-in replacement module that delegates to the
747        #    native callable rather than the original code
748        # 3) Run both the original code and native callable wrapper with
749        #    equivalent results
750        # 4) TorchScript compile the native callable wrapper and confirm
751        #    equivalent results with the reference
752        # 5) TorchScript serialize and deserialize the native callable
753        #    and confirm equivalent results with the reference
754
755        # We use this simple Module as a reference computation
756        class MySimpleMod(torch.nn.Module):
757            def forward(self, x):
758                return 3.0 * x + x
759
760        msm = MySimpleMod()
761
762        # This is what a lowering pass might look like: a function that takes
763        # a valid nn.Module, symbolically traces it, lowers the Module to some
764        # representation, and wraps that representation up into another
765        # nn.Module instance that handles dispatch to the compiled/lowered code.
766        def lower_to_elementwise_interpreter(orig_mod : torch.nn.Module) -> torch.nn.Module:
767            # ===== Stage 1: Symbolic trace the module =====
768            mod = symbolic_trace(orig_mod)
769
770            # ===== Stage 2: Lower GraphModule representation to the C++
771            #       interpreter's instruction format ======
772            instructions = []
773            constant_idx = 0
774            constants = {}
775            fn_input_names = []
776
777            target_to_name = {
778                operator.add : "add",
779                operator.mul : "mul"
780            }
781
782            output_node : Optional[Node] = None
783            # For each instruction, create a triple
784            # (instruction_name : str, inputs : List[str], output : str)
785            # to feed into the C++ interpreter
786            for n in mod.graph.nodes:
787                target, args, out_name = n.target, n.args, n.name
788                assert len(n.kwargs) == 0, "kwargs currently not supported"
789
790                if n.op == 'placeholder':
791                    # Placeholders specify function argument names. Save these
792                    # for later when we generate the wrapper GraphModule
793                    fn_input_names.append(target)
794                elif n.op == 'call_function':
795                    assert target in target_to_name, "Unsupported call target " + target
796                    arg_names = []
797                    for arg in args:
798                        if not isinstance(arg, Node):
799                            # Pull out constants. These constants will later be
800                            # fed to the interpreter C++ object via add_constant()
801                            arg_name = f'constant_{constant_idx}'
802                            constants[arg_name] = torch.tensor(
803                                [arg] if isinstance(arg, numbers.Number) else arg)
804                            arg_names.append(arg_name)
805                            constant_idx += 1
806                        else:
807                            arg_names.append(arg.name)
808                    instructions.append((target_to_name[target], arg_names, out_name))
809                elif n.op == 'output':
810                    if output_node is not None:
811                        raise RuntimeError('Multiple output nodes!')
812                    output_node = n
813                else:
814                    raise RuntimeError('Unsupported opcode ' + n.op)
815
816            interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter()
817            # Load constants
818            for k, v in constants.items():
819                interpreter.add_constant(k, v)
820            # Specify names for positional input arguments
821            interpreter.set_input_names(fn_input_names)
822            # Load instructions
823            interpreter.set_instructions(instructions)
824            # Specify name for single output
825            assert isinstance(output_node.args[0], torch.fx.Node)
826            interpreter.set_output_name(output_node.args[0].name)
827
828            # ===== Stage 3: Create a wrapper GraphModule around the interpreter =====
829            class WrapperModule(torch.nn.Module):
830                def __init__(self, interpreter):
831                    super().__init__()
832                    self.interpreter = interpreter
833
834            wrapper = WrapperModule(interpreter)
835
836            # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter
837            # 3) Returns the speficied return value
838
839            # FIXME: The following code could be greatly simplified by symbolic_trace'ing
840            # the wrapper with a Tracer that considers the Wrapper instance a root
841            # module, however, I can't get `__call__` exposed on TorchBind classes
842            # without it messing up Python `hasattr` for some reason. More digging
843            # into CPython's implementation of hasattr is probably in order...
844
845            graph = torch.fx.Graph()
846            # Add placeholders for fn inputs
847            placeholder_nodes = []
848            for name in fn_input_names:
849                placeholder_nodes.append(graph.create_node('placeholder', name))
850
851            # Get the interpreter object
852            interpreter_node = graph.create_node('get_attr', 'interpreter')
853
854            # Add a node to call the interpreter instance
855            output_node = graph.create_node(
856                op='call_method', target='__call__', args=(interpreter_node, placeholder_nodes))
857
858            # Register output
859            graph.output(output_node)
860
861            graph.lint()
862
863            # Return final GraphModule!!!
864            return GraphModule(wrapper, graph)
865
866        # Lower GraphModule to C++ interpreter
867        lowered = lower_to_elementwise_interpreter(msm)
868
869        # Compare correctness with original module
870        x = torch.rand(3, 4)
871        ref_out = msm(x)
872        test_out = lowered(x)
873        torch.testing.assert_close(test_out, ref_out)
874
875        # Test TorchScript compilation
876        scripted_lowered = torch.jit.script(lowered)
877        script_out = scripted_lowered(x)
878        torch.testing.assert_close(script_out, ref_out)
879
880        # Test TorchScript ser/de
881        import_copy = self.getExportImportCopy(scripted_lowered)
882        imported_out = import_copy(x)
883        torch.testing.assert_close(imported_out, ref_out)
884
885    def test_reserved_getattr(self):
886        """Ensure that we do not name any nodes with a reserved builtin like `getattr`"""
887        class M(torch.nn.Module):
888            def forward(self, a):
889                return a.foo.bar.baz
890
891        m = M()
892        m_g = symbolic_trace(m)
893        m_g.graph.lint()
894        for node in m_g.graph.nodes:
895            self.assertTrue(node.name != "getattr")
896
897    @unittest.skip("Hotfix for SEV remediation")
898    def test_trace_buffer_slice(self):
899        bs, d_hid = 10, 23
900
901        class ExampleCode(torch.nn.Module):
902            def __init__(self) -> None:
903                super().__init__()
904                self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid))
905                self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
906                self.lin = torch.nn.Linear(d_hid, d_hid)
907                self.buffer = torch.nn.Buffer(torch.randn(bs + 100, d_hid))
908
909            def forward(self, x):
910                x = torch.mm(x, self.mm_param)
911                skip_connection = x
912                x = torch.relu(x)
913                x = torch.mm(x, self.mm_param) + self.buffer[:x.shape[0]]
914                x = self.lin(x)
915                x = torch.relu(x)
916                x = x + skip_connection
917                x = torch.mm(x, self.mm_param2)
918                x = self.lin(x)
919                return x
920
921        ec = ExampleCode()
922
923        traced = torch.fx.symbolic_trace(ec)
924
925        x = torch.randn(bs, d_hid)
926        torch.testing.assert_close(ec(x), traced(x))
927
928    def test_node_tagging(self):
929        class TaggingTracer(Tracer):
930            def create_node(self, kind : str, target : Union[str, Callable],
931                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
932                            type_expr : Optional[Any] = None) -> Node:
933                n = super().create_node(kind, target, args, kwargs, name)
934                n.tag = 'foo'
935                return n
936
937        class M(torch.nn.Module):
938            def forward(self, a, b):
939                return a + b
940
941        m = M()
942        g = TaggingTracer().trace(m)
943        g.lint()
944        for n in g.nodes:
945            self.assertTrue(hasattr(n, 'tag'))
946            self.assertEqual(n.tag, 'foo')
947
948    def test_tensor_attribute(self):
949        class TensorAttribute(torch.nn.Module):
950            def __init__(self) -> None:
951                super().__init__()
952                self.tensor = torch.rand(3, 4)
953
954            def forward(self, x):
955                return torch.nn.functional.linear(x, self.tensor)
956
957        ta = TensorAttribute()
958        traced = symbolic_trace(ta)
959        traced(torch.rand(4, 4))
960
961        class WrapperForQualname(torch.nn.Module):
962            def __init__(self) -> None:
963                super().__init__()
964                self.ta = TensorAttribute()
965
966            def forward(self, x):
967                return torch.nn.functional.linear(x, self.ta.tensor)
968
969        wfq = WrapperForQualname()
970        traced2 = symbolic_trace(wfq)
971        traced2.graph.lint()
972        traced2(torch.rand(4, 4))
973
974    def test_tensor_attribute_coalseced(self):
975
976        def count_attrs(fx_module):
977            targets = set()
978            for node in traced.graph.nodes:
979                if node.op == 'get_attr':
980                    targets.add(node.target)
981            return len(targets)
982
983        val = torch.tensor(5)
984
985        def f(x):
986            return x + val + val
987        traced = symbolic_trace(f)
988        traced.graph.lint()
989        self.assertEqual(count_attrs(traced), 1)
990
991        val2 = torch.tensor(5)
992
993        def f(x):
994            val = torch.tensor(5)
995            return x + val + val2
996
997        traced = symbolic_trace(f)
998        traced.graph.lint()
999        self.assertEqual(count_attrs(traced), 2)
1000
1001    def test_symbolic_trace_sequential(self):
1002        class Simple(torch.nn.Module):
1003            def forward(self, x):
1004                return torch.neg(x)
1005
1006        seq = torch.nn.Sequential(
1007            Simple(),
1008            Simple(),
1009            Simple()
1010        )
1011        traced = symbolic_trace(seq)
1012        traced.graph.lint()
1013        x = torch.rand(3, 4)
1014        self.assertEqual(traced(x), seq(x))
1015
1016    def test_tensor_constant(self):
1017        class ConstTensor(torch.nn.Module):
1018            def forward(self, x):
1019                return torch.nn.functional.linear(x, torch.zeros(3, 4))
1020
1021        ct = ConstTensor()
1022        traced = symbolic_trace(ct)
1023        traced.graph.lint()
1024        traced(torch.rand(4, 4))
1025
1026    def test_pickle_graphmodule(self):
1027        class Nested(torch.nn.Module):
1028            def __init__(self) -> None:
1029                super().__init__()
1030                self.st = torch.nn.Linear(4, 4)
1031
1032            def forward(self, x):
1033                return self.st(x)
1034
1035        n = Nested()
1036        traced = symbolic_trace(n)
1037        traced.graph.lint()
1038        pickled = pickle.dumps(traced)
1039        loaded = pickle.loads(pickled)
1040        loaded.graph.lint()
1041        x = torch.rand(3, 4)
1042        self.assertEqual(loaded(x), traced(x))
1043
1044    def test_pickle_custom_import(self):
1045        graph = torch.fx.Graph()
1046        a = graph.placeholder('x')
1047        b = graph.placeholder('y')
1048        c = graph.call_function(a_non_torch_leaf, (a, b))
1049        d = graph.call_function(torch.sin, (c,))
1050        graph.output(d)
1051        gm = GraphModule(torch.nn.Module(), graph)
1052        pickled = pickle.dumps(gm)
1053        loaded = pickle.loads(pickled)
1054        loaded.graph.lint()
1055        x, y = torch.rand(1), torch.rand(1)
1056        self.assertEqual(loaded(x, y), gm(x, y))
1057
1058    def test_all_input_nodes(self):
1059        graph : torch.fx.Graph = torch.fx.Graph()
1060        a : torch.fx.Node = graph.placeholder('x')
1061        b : torch.fx.Node = graph.call_module('linear_mod', args=(a,))
1062        c : torch.fx.Node = graph.get_attr('y_attr')
1063        d : torch.fx.Node = graph.call_function(operator.add, args=(b, c))
1064        e : torch.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0))
1065        graph.output(e)
1066        graph.lint()
1067
1068        self.assertEqual(b.all_input_nodes, [a])
1069        self.assertEqual(c.all_input_nodes, [])
1070        self.assertEqual(d.all_input_nodes, [b, c])
1071        self.assertEqual(e.all_input_nodes, [d])
1072
1073    def test_deepcopy_graphmodule_with_transform(self):
1074        st = SimpleTest()
1075        traced = symbolic_trace(st)
1076        traced.graph.lint()
1077
1078        def transform(traced):
1079            new_graph = torch.fx.Graph()
1080            val_map : Dict[Node, Node] = {}
1081            output_value = new_graph.graph_copy(traced.graph, val_map)
1082            relu_out = new_graph.create_node(
1083                op='call_method', target='neg', args=(output_value,), kwargs={})
1084            new_graph.output(relu_out)
1085            return GraphModule(traced, new_graph)
1086        transformed = transform(traced)
1087        transformed.graph.lint()
1088        copied = copy.deepcopy(transformed)
1089        self.assertNotEqual(id(type(transformed)), id(type(copied)))
1090        x = torch.randn(3, 4)
1091        self.assertEqual(copied(x), transformed(x))
1092
1093    def test_deepcopy_with_submods_params(self):
1094        class Bar(torch.nn.Module):
1095            def __init__(self) -> None:
1096                super().__init__()
1097                self.param = torch.nn.Parameter(torch.rand(3, 4))
1098
1099            def forward(self, x):
1100                return torch.relu(x) + self.param
1101
1102        class Baz(torch.nn.Module):
1103            def __init__(self) -> None:
1104                super().__init__()
1105                self.param = torch.nn.Parameter(torch.rand(3, 4))
1106                self.bar = Bar()
1107
1108            def forward(self, x):
1109                return self.bar(x) - self.param
1110
1111        baz = Baz()
1112        traced = symbolic_trace(baz)
1113        traced.graph.lint()
1114        copied = copy.deepcopy(traced)
1115        copied.graph.lint()
1116
1117    def test_deepcopy_graph_with_tracer_cls(self):
1118        class TestTracer(Tracer):
1119            def is_leaf_module(self, module, name):
1120                return True
1121
1122        g = Graph(tracer_cls=TestTracer)
1123        x = g.placeholder("x")
1124        g.output(x)
1125
1126        h = copy.deepcopy(g)
1127        self.assertIsNotNone(h._tracer_cls)
1128        self.assertTrue(g._tracer_cls == h._tracer_cls)
1129
1130    def test_unpack_list_better_error(self):
1131        class SomeArgs(torch.nn.Module):
1132            def forward(self, a, b):
1133                return torch.rand(3, 4)
1134
1135        class UnpacksList(torch.nn.Module):
1136            def __init__(self) -> None:
1137                super().__init__()
1138                self.sa = SomeArgs()
1139
1140            def forward(self, x : list):
1141                return self.sa(*x)
1142
1143        ul = UnpacksList()
1144        with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
1145            symbolic_trace(ul)
1146
1147    def test_unpack_dict_better_error(self):
1148        class SomeKwargs(torch.nn.Module):
1149            def forward(self, x=3, y=4):
1150                return torch.rand(3, 4)
1151
1152        class UnpacksDict(torch.nn.Module):
1153            def __init__(self) -> None:
1154                super().__init__()
1155                self.sk = SomeKwargs()
1156
1157            def forward(self, x : dict):
1158                return self.sk(**x)
1159
1160        ud = UnpacksDict()
1161        with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
1162            symbolic_trace(ud)
1163
1164    def test_pretty_print_targets(self):
1165        # Test that Graph pretty-print prints friendly name for targets
1166        # in `operator` and `builtins`
1167
1168        class SomeMod(torch.nn.Module):
1169            def forward(self, x):
1170                return torch.add(x.foo + x.bar, 3.0)
1171
1172        traced = symbolic_trace(SomeMod())
1173        graph_str = str(traced.graph)
1174        self.assertIn('builtins.getattr', graph_str)
1175        self.assertIn('operator.add', graph_str)
1176        self.assertIn('torch.add', graph_str)
1177
1178    def test_pretty_print_node(self):
1179        class M(torch.nn.Module):
1180            def __init__(self) -> None:
1181                super().__init__()
1182                self.param: torch.nn.Parameter = torch.nn.Parameter(
1183                    torch.rand(3, 4))
1184                self.linear = torch.nn.Linear(4, 5)
1185
1186            def forward(self, x: torch.Tensor, y: int = 2):
1187                return self.linear(x[y] + self.param).clamp(min=0.0, max=1.0)
1188
1189        traced = symbolic_trace(M())
1190
1191        all_formatted = "\n".join([n.format_node() for n in traced.graph.nodes])
1192
1193        FileCheck().check("x").check("placeholder") \
1194            .check("y").check("placeholder") \
1195            .check("getitem").check("call_function") \
1196            .check("param").check("get_attr") \
1197            .check("add").check("call_function") \
1198            .check("linear").check("call_module") \
1199            .check("clamp").check("call_method") \
1200            .run(all_formatted)
1201
1202    def test_script_tensor_constant(self):
1203        # TorchScript seems to ignore attributes that start with `__`.
1204        # We used to call anonymous Tensor values `__tensor_constant*`, but
1205        # they were getting ignored by script. Now they're called
1206        # `_tensor_constant*`
1207        class IHaveATensorConstant(torch.nn.Module):
1208            def forward(self, x):
1209                return x + torch.rand(3, 4)
1210
1211        traced = torch.fx.symbolic_trace(IHaveATensorConstant())
1212        torch.jit.script(traced)
1213
1214    def test_autowrap_functions(self):
1215        class AutowrapFnTest(torch.nn.Module):
1216            def forward(self, x):
1217                return fx_int(x.shape[0] / 2)
1218
1219        class AutowrapFnTest2(torch.nn.Module):
1220            def forward(self, x):
1221                return fx_int(x.shape[0] / 2) + fx_int_x2(x.shape[0] / 2)
1222
1223        # Check function(s) are wrapped
1224        # `int` would normally throw a TypeError as argument can't be `Proxy`
1225        tracer = Tracer(autowrap_functions=(fx_int,))
1226        graph = tracer.trace(AutowrapFnTest())
1227        traced = GraphModule(tracer.root, graph, 'test')
1228        tracer_2 = Tracer(autowrap_functions=(fx_int, fx_int_x2))
1229        tracer_2.trace(AutowrapFnTest2())
1230
1231        # Test scriptability
1232        traced_scripted = torch.jit.script(traced)
1233        self.assertEqual(traced_scripted(torch.rand(4)), 2)
1234
1235    def test_tuple_no_subscript(self):
1236        def foo(x : Tuple):
1237            return x[0]
1238
1239        traced = torch.fx.symbolic_trace(foo)
1240        x = (torch.randn(5, 3),)
1241        torch.testing.assert_close(traced(x), x[0])
1242
1243        bio = io.BytesIO()
1244
1245        torch.save(traced, bio)
1246
1247        bio.seek(0)
1248
1249        # weights_only=False as this loads a GraphModule
1250        # GLOBAL torch.fx.graph_module.reduce_graph_module was not an allowed global by default
1251        loaded = torch.load(bio, weights_only=False)
1252
1253        torch.testing.assert_close(loaded(x), x[0])
1254
1255    def test_torch_fx_len(self):
1256        class FXLenTest(torch.nn.Module):
1257            def forward(self, x):
1258                return len(x)
1259
1260        traced = symbolic_trace(FXLenTest())
1261        self.assertEqual(traced(torch.rand(3, 4)), 3)
1262
1263        # Test scriptability
1264        scripted = torch.jit.script(FXLenTest())
1265        self.assertEqual(scripted(torch.rand(3)), 3)
1266
1267        traced_scripted = torch.jit.script(traced)
1268        self.assertEqual(traced_scripted(torch.rand(3)), 3)
1269
1270        # Test non-proxy len
1271        class FXLenTest2(torch.nn.Module):
1272            def __init__(self) -> None:
1273                super().__init__()
1274                self.l = [3, 4, 5]
1275
1276            def forward(self, x):
1277                return x + len(self.l)
1278
1279        traced2 = symbolic_trace(FXLenTest2())
1280        inp = torch.rand(3, 4)
1281        self.assertEqual(traced2(inp), inp + 3.0)
1282        self.assertIs(len, builtins.len)
1283
1284    def test_torch_fx_getattr(self):
1285        class FXGetattrTest(torch.nn.Module):
1286            def forward(self, x):
1287                return getattr(x, 'nonexistent_attr', torch.Tensor([2, 3]))
1288
1289        traced = symbolic_trace(FXGetattrTest())
1290        self.assertEqual(traced(torch.rand(3, 4)), torch.Tensor([2, 3]))
1291
1292    def test_sqrt(self):
1293        class Sqrt1(torch.nn.Module):
1294            def forward(self, x):
1295                return sqrt(x.size(0))
1296
1297        class Sqrt2(torch.nn.Module):
1298            def forward(self, x):
1299                return math.sqrt(x.size(0))
1300
1301        class Sqrt3(torch.nn.Module):
1302            def forward(self, x):
1303                return x + math.sqrt(2) + sqrt(2)
1304
1305        self.checkGraphModule(Sqrt1(), [torch.zeros(8)])
1306        self.checkGraphModule(Sqrt2(), [torch.zeros(8)])
1307        self.checkGraphModule(Sqrt3(), [torch.zeros(8)])
1308        self.assertIs(sqrt, _sqrt)
1309        self.assertIs(math.sqrt, _sqrt)
1310
1311    def test_torch_custom_ops(self):
1312        class M(torch.nn.Module):
1313            def forward(self, a):
1314                b = torch.ops.aten.sigmoid(a)
1315                c = torch.ops.aten.cat([a, b])
1316                return torch.ops.aten.cat((c, c))
1317        m = M()
1318        input = torch.randn(3)
1319        ref_out = m(input)
1320        gm = symbolic_trace(m)
1321        gm.graph.lint()
1322        out = gm(input)
1323        self.assertEqual(out, ref_out)
1324
1325    def test_torch_op_overloads(self):
1326        class M(torch.nn.Module):
1327            def forward(self, a):
1328                b = torch.ops.aten.add.Tensor(a, a)
1329                return b
1330        m = M()
1331        input = torch.randn(3)
1332        ref_out = m(input)
1333        gm = symbolic_trace(m)
1334        gm.graph.lint()
1335        out = gm(input)
1336        self.assertEqual(out, ref_out)
1337
1338        for node in gm.graph.nodes:
1339            if node.op == 'call_function':
1340                assert isinstance(node.target, torch._ops.OpOverload)
1341                assert node.target.__name__ == 'add.Tensor'
1342
1343    def test_pickle_torch_custom_ops(self):
1344        class M(torch.nn.Module):
1345            def forward(self, a):
1346                b = torch.ops.aten.sigmoid(a)
1347                c = torch.ops.aten.cat([a, b])
1348                return torch.ops.aten.cat((c, c))
1349        m = M()
1350        input = torch.randn(3)
1351        ref_out = m(input)
1352        gm = symbolic_trace(m)
1353        gm.graph.lint()
1354        pickled = pickle.dumps(gm)
1355        loaded = pickle.loads(pickled)
1356        self.assertEqual(loaded(input), gm(input))
1357
1358    def test_pretty_print(self):
1359        st = SimpleTest()
1360        traced = symbolic_trace(st)
1361        traced.graph.lint()
1362        printed = str(traced)
1363        assert 'SimpleTest()' in printed
1364        assert 'torch.relu' in printed
1365
1366    def test_pretty_print_graph(self):
1367        class KwargPrintTest(torch.nn.Module):
1368            def forward(self, x):
1369                return torch.squeeze(x + 3.0, dim=2)
1370        st = KwargPrintTest()
1371        traced = symbolic_trace(st)
1372        traced.graph.lint()
1373        stringed = str(traced.graph)
1374        for s in ['args', 'kwargs', 'num_users']:
1375            assert s in stringed
1376
1377    def test_custom_proxy_type(self):
1378        class TensorPair:
1379            def __init__(self, left, right):
1380                self.left, self.right = left, right
1381
1382            def add(self, other):
1383                l = self.left + other.left
1384                r = self.right + other.right
1385                return TensorPair(l, r)
1386
1387            def mul(self, other):
1388                l = self.left * other.left
1389                r = self.right * other.right
1390                return TensorPair(l, r)
1391
1392        def use_tensor_pair(x : TensorPair, y : TensorPair):
1393            s = x.add(y)
1394            return s.mul(x)
1395
1396        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
1397        y = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
1398
1399        ref_out = use_tensor_pair(x, y)
1400
1401        traced = symbolic_trace(use_tensor_pair)
1402
1403        traced_out = traced(x, y)
1404        self.assertEqual(traced_out.left, ref_out.left)
1405        self.assertEqual(traced_out.right, ref_out.right)
1406
1407    def test_custom_proxy_type_literal(self):
1408        class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
1409            def __init__(self, left, right):
1410                self.left, self.right = left, right
1411
1412            def add(self, other):
1413                l = self.left + other.left
1414                r = self.right + other.right
1415                return TensorPair(l, r)
1416
1417            def mul(self, other):
1418                l = self.left * other.left
1419                r = self.right * other.right
1420                return TensorPair(l, r)
1421
1422        def use_tensor_pair_literal(x : TensorPair):
1423            s = x.add(TensorPair(torch.zeros(5, 3), torch.zeros(5, 3)))
1424            return s.mul(x)
1425
1426        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
1427
1428        ref_out = use_tensor_pair_literal(x)
1429
1430        traced = symbolic_trace(use_tensor_pair_literal)
1431
1432        traced_out = traced(x)
1433        self.assertEqual(traced_out.left, ref_out.left)
1434        self.assertEqual(traced_out.right, ref_out.right)
1435
1436    def test_custom_proxy_dynamic_value(self):
1437        class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
1438            def __init__(self, left, right):
1439                self.left, self.right = left, right
1440
1441            def add(self, other):
1442                l = self.left + other.left
1443                r = self.right + other.right
1444                return TensorPair(l, r)
1445
1446            def mul(self, other):
1447                l = self.left * other.left
1448                r = self.right * other.right
1449                return TensorPair(l, r)
1450
1451        def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor):
1452            s = x.add(TensorPair(y, y))
1453            return s.mul(x)
1454
1455        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
1456        y = torch.randn(5, 3)
1457        ref_out = use_tensor_pair_ctor(x, y)
1458
1459        traced = symbolic_trace(use_tensor_pair_ctor)
1460
1461        traced_out = traced(x, y)
1462        self.assertEqual(traced_out.left, ref_out.left)
1463        self.assertEqual(traced_out.right, ref_out.right)
1464
1465    def test_custom_proxy_input_dependent_control_flow(self):
1466        class ZeroTensor(metaclass=torch.fx.ProxyableClassMeta):
1467            def __init__(self, inp):
1468                if inp.sum() == 0:
1469                    self.is_zero = True
1470                    self.tensor = torch.tensor([])
1471                else:
1472                    self.is_zero = False
1473                    self.tensor = inp
1474
1475            def add(self, other):
1476                if self.is_zero:
1477                    return ZeroTensor(other.tensor)
1478                elif other.is_zero:
1479                    return self
1480
1481        def use_zero_tensor(x : torch.Tensor, y : torch.Tensor):
1482            return ZeroTensor(x + y)
1483
1484        x, y = torch.randn(5, 3), torch.randn(5, 3)
1485
1486        ref_out = use_zero_tensor(x, y)
1487
1488        traced = symbolic_trace(use_zero_tensor)
1489
1490        traced_out = traced(x, y)
1491
1492        self.assertEqual(traced_out.is_zero, ref_out.is_zero)
1493        self.assertEqual(traced_out.tensor, ref_out.tensor)
1494
1495    def test_graph_fns(self):
1496        g = Graph()
1497        a = g.placeholder('a')
1498        b = g.call_module('linear', (a,))
1499        c = g.get_attr('bias')
1500        d = g.call_method('add', (b, c))
1501        e = g.call_function(torch.sin, (d,))
1502        g.output(e)
1503        mod = torch.nn.Module()
1504        mod.linear = torch.nn.Linear(3, 4)
1505        mod.bias = torch.rand(4)
1506        gm = GraphModule(mod, g)
1507        gm.graph.lint()
1508        input = torch.rand(3)
1509        r = gm(input)
1510        ref = torch.sin(mod.linear(input) + mod.bias)
1511        self.assertEqual(r, ref)
1512
1513    def test_remove_uses(self):
1514        g : torch.fx.Graph = Graph()
1515        x : torch.fx.Node = g.placeholder('x')
1516        relu : torch.fx.Node = g.call_function(torch.relu, (x,))
1517        neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
1518        g.output(neg)
1519
1520        neg.replace_all_uses_with(relu)
1521        g.erase_node(neg)
1522
1523        self.assertTrue(neg not in relu.users)
1524
1525    def test_remove_uses_with_custom_filter(self):
1526        g : torch.fx.Graph = Graph()
1527        x : torch.fx.Node = g.placeholder('x')
1528        relu : torch.fx.Node = g.call_function(torch.relu, (x,))
1529        neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
1530        g.output(neg)
1531
1532        neg.replace_all_uses_with(relu, lambda x: x != neg)
1533
1534        self.assertTrue(neg in relu.users)
1535
1536    def test_nonetype_annotation(self):
1537        eb = torch.nn.EmbeddingBag(3, 4)
1538        symbolic_trace(eb)
1539
1540    def test_pickle_nonetype_annotation(self):
1541        eb = torch.nn.EmbeddingBag(10, 3, mode='sum')
1542        traced = symbolic_trace(eb)
1543        pickled = pickle.dumps(traced)
1544        loaded = pickle.loads(pickled)
1545        loaded.graph.lint()
1546        input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
1547        offsets = torch.LongTensor([0, 4])
1548        self.assertEqual(loaded(input, offsets), traced(input, offsets))
1549
1550    def test_return_tuple(self):
1551        class M(torch.nn.Module):
1552            def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
1553                return (x, x + x)
1554
1555        original = M()
1556        traced = symbolic_trace(original)
1557        self.assertEqual(traced(torch.ones(1)), original.forward(torch.ones(1)))
1558
1559    def test_construct_root_dict(self):
1560        graph : torch.fx.Graph = torch.fx.Graph()
1561        a : torch.fx.Node = graph.create_node('placeholder', 'x')
1562        b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
1563        c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
1564        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
1565        graph.output(d)
1566
1567        linear_mod : torch.nn.Module = torch.nn.Linear(3, 4)
1568        add_param : torch.Tensor = torch.rand(3, 4)
1569        gm : torch.fx.GraphModule = torch.fx.GraphModule(
1570            {'foo.bar.baz': linear_mod, 'zip.zap.zam' : add_param}, graph)
1571        gm.graph.lint()
1572
1573        assert 'self.foo.bar.baz' in gm.code
1574
1575        x : torch.Tensor = torch.rand(3, 3)
1576        out : torch.Tensor = gm(x)
1577        ref_out : torch.Tensor = linear_mod(x) + add_param
1578        self.assertEqual(out, ref_out)
1579
1580    def test_symbolic_trace_assert(self):
1581
1582        class AssertsTensorShape(torch.nn.Module):
1583            def forward(self, x):
1584                torch._assert(x.shape[1] > 4, "assert_foobar")
1585                return x
1586
1587        m = AssertsTensorShape()
1588        # verify traceability
1589        traced = symbolic_trace(m)
1590        # verify assertion on traced model works correctly at runtime
1591        traced(torch.rand(4, 5))
1592        with self.assertRaisesRegex(AssertionError, "assert_foobar"):
1593            traced(torch.rand(4, 3))
1594        # verify the symbolically traced module is scriptable
1595        ms = torch.jit.script(m)
1596        with self.assertRaisesRegex(torch.jit.Error, "assert_foobar"):
1597            ms(torch.rand(4, 3))
1598
1599    def test_fx_create_arg(self):
1600        class CustomArgObject:
1601            def __init__(self, x, y):
1602                self.x = x
1603                self.y = y
1604
1605            def __fx_create_arg__(self, tracer: torch.fx.Tracer):
1606                return tracer.create_node(
1607                    "call_function",
1608                    CustomArgObject,
1609                    args=(
1610                        tracer.create_arg(self.x),
1611                        tracer.create_arg(self.y),
1612                    ),
1613                    kwargs={},
1614                )
1615
1616        class HasCustomArgObjectWhenLeaf(torch.nn.Module):
1617            def forward(self, o: CustomArgObject):
1618                # Not normally traceable; good reason to make
1619                # this module a leaf.
1620                for x in o.x:
1621                    o.y += x
1622                return o.y
1623
1624        class Root(torch.nn.Module):
1625            def __init__(self) -> None:
1626                super().__init__()
1627                self.inner = HasCustomArgObjectWhenLeaf()
1628
1629            def forward(self, x, y):
1630                o = CustomArgObject(x, y)
1631                return self.inner(o)
1632
1633        class CreateArgTracer(torch.fx.Tracer):
1634            def is_leaf_module(self, m, module_qualified_name):
1635                return type(m) is HasCustomArgObjectWhenLeaf
1636
1637        m = Root()
1638        graph = CreateArgTracer().trace(m)
1639        gm = torch.fx.GraphModule(m, graph)
1640        assert "CustomArgObject(" in gm.code
1641
1642    def test_trace_fn_constant(self):
1643        some_constant = torch.rand(3, 4)
1644
1645        def add_const(x):
1646            return some_constant + x
1647
1648        traced = symbolic_trace(add_const)
1649
1650        input = torch.rand(3, 4)
1651        self.assertEqual(traced(input), add_const(input))
1652
1653    def test_copy_no_remap(self):
1654        traced = symbolic_trace(SimpleTest())
1655        g = traced.graph
1656        copied = torch.fx.Graph()
1657        for node in g.nodes:
1658            copied.node_copy(node)
1659        with self.assertRaisesRegex(RuntimeError, 'does not belong to this Graph'):
1660            copied.lint()
1661
1662    def test_wrong_topo(self):
1663        graph : torch.fx.Graph = torch.fx.Graph()
1664        a : torch.fx.Node = graph.create_node('placeholder', 'x')
1665        b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
1666        c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
1667        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
1668        graph.output(d)
1669        nodes = list(graph.nodes)
1670        nodes[3].append(nodes[2])
1671        with self.assertRaisesRegex(RuntimeError, 'was used before it has been defined'):
1672            graph.lint()
1673
1674    def test_wrong_target_type(self):
1675        graph : torch.fx.Graph = torch.fx.Graph()
1676        with self.assertRaises(ValueError):
1677            n = torch.fx.Node(graph=graph, name='foo', op='call_function', target='foo',
1678                              args=(), kwargs={})
1679
1680    def test_example_shape_prop(self):
1681        class TestCase(torch.nn.Module):
1682            def __init__(self) -> None:
1683                super().__init__()
1684                self.attr = torch.randn(3, 4)
1685                self.submod = torch.nn.Linear(4, 4)
1686
1687            def forward(self, x):
1688                return torch.neg(self.submod(x.relu() + self.attr))
1689        tc = TestCase()
1690        tc_traced = symbolic_trace(tc)
1691        ref_out = tc_traced(torch.rand(3, 4))
1692        shape_prop.ShapeProp(tc_traced).propagate(torch.rand(3, 4))
1693
1694        # Make sure we're testing all opcodes
1695        opcodes = set()
1696        output_shape : Optional[torch.Shape] = None
1697        output_stride : Optional[Tuple[int]] = None
1698        for node in tc_traced.graph.nodes:
1699            opcodes.add(node.op)
1700            if node.op == 'output':
1701                output_shape = node.args[0].meta['tensor_meta'].shape
1702                output_stride = node.args[0].meta['tensor_meta'].stride
1703        self.assertEqual(opcodes, {'placeholder', 'get_attr', 'call_function', 'call_method',
1704                                   'call_module', 'output'})
1705
1706        # Test shape propagation and make sure results match actual
1707        self.assertEqual(output_shape, ref_out.shape)
1708        self.assertEqual(output_stride, ref_out.stride())
1709
1710    def test_shape_prop_layout(self):
1711        class ConvTest(torch.nn.Module):
1712            def __init__(self) -> None:
1713                super().__init__()
1714                self.conv_mod = torch.nn.Conv2d(5, 5, 3)
1715
1716            def forward(self, x):
1717                return self.conv_mod(x)
1718
1719        # contiguous layout
1720        test_mod = ConvTest()
1721        traced = symbolic_trace(test_mod)
1722        x = torch.randn(5, 5, 224, 224)
1723        shape_prop.ShapeProp(traced).propagate(x)
1724
1725        assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
1726                   for node in traced.graph.nodes)
1727
1728        x_channels_last = x.contiguous(memory_format=torch.channels_last)
1729        traced.to(memory_format=torch.channels_last)
1730        shape_prop.ShapeProp(traced).propagate(x_channels_last)
1731        for node in traced.graph.nodes:
1732            # NB: the implementation of conv may not preserve the memory format,
1733            # unfortunately. The best we can do is just check that the placeholder
1734            # node is channels-last
1735            if node.op in {'placeholder'}:
1736                self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last)
1737
1738    def test_shape_prop_aggregate(self):
1739        class ReturnTwo(torch.nn.Module):
1740            def forward(self, x):
1741                return (3, torch.sum(x))
1742
1743        class UnderTest(torch.nn.Module):
1744            def __init__(self) -> None:
1745                super().__init__()
1746                self.rt = ReturnTwo()
1747
1748            def forward(self, x):
1749                return self.rt(x)
1750
1751        ut = UnderTest()
1752
1753        class RTTracer(torch.fx.Tracer):
1754            def is_leaf_module(self, m, module_qualified_name):
1755                return type(m) is ReturnTwo
1756
1757        graph = RTTracer().trace(ut)
1758        mod = torch.fx.GraphModule(ut, graph)
1759
1760        shape_prop.ShapeProp(mod).propagate(torch.rand(3, 4))
1761
1762        for node in mod.graph.nodes:
1763            if node.op == 'call_module':
1764                assert 'tensor_meta' in node.meta
1765                tensor_meta = node.meta['tensor_meta']
1766                assert tensor_meta[0] == 3
1767                assert tensor_meta[1].shape == torch.Size([])
1768
1769    def test_shape_prop_layout_3d(self):
1770        class ConvTest3d(torch.nn.Module):
1771            def __init__(self) -> None:
1772                super().__init__()
1773                self.conv_mod = torch.nn.Conv3d(5, 5, 3)
1774
1775            def forward(self, x):
1776                return self.conv_mod(x)
1777
1778        test_mod_3d = ConvTest3d()
1779        traced_3d = symbolic_trace(test_mod_3d)
1780        x_3d = torch.randn(5, 5, 224, 224, 15)
1781        shape_prop.ShapeProp(traced_3d).propagate(x_3d)
1782        assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
1783                   for node in traced_3d.graph.nodes)
1784
1785        x_channels_last_3d = x_3d.contiguous(memory_format=torch.channels_last_3d)
1786        traced_3d.to(memory_format=torch.channels_last_3d)
1787        shape_prop.ShapeProp(traced_3d).propagate(x_channels_last_3d)
1788        for node in traced_3d.graph.nodes:
1789            # NB: the implementation of conv may not preserve the memory format,
1790            # unfortunately. The best we can do is just check that the placeholder
1791            # node is channels-last
1792            if node.op in {'placeholder'}:
1793                self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last_3d)
1794
1795    def test_nn_module_stack(self):
1796        class SubModule(torch.nn.Module):
1797            def __init__(self) -> None:
1798                super().__init__()
1799                self.conv_mod = torch.nn.Conv2d(64, 64, (3, 3), padding=1, bias=False)
1800
1801            def forward(self, x):
1802                return self.conv_mod(x)
1803
1804        class MyModule(torch.nn.Module):
1805            def __init__(self) -> None:
1806                super().__init__()
1807                self.sub_mod = SubModule()
1808
1809            def forward(self, x):
1810                return self.sub_mod(x)
1811
1812        m = MyModule()
1813        gm = torch.fx.symbolic_trace(m)
1814
1815        mod_stack = {}
1816        expected_stack = [('sub_mod', ('sub_mod', type(m.sub_mod))),
1817                          ('sub_mod.conv_mod', ('sub_mod.conv_mod', type(m.sub_mod.conv_mod)))]
1818        for node in gm.graph.nodes:
1819            mod_stack = node.meta.get('nn_module_stack', {})
1820            if mod_stack:
1821                break
1822        stack_list = list(mod_stack.items())
1823        self.assertEqual(stack_list, expected_stack)
1824
1825    def test_transformer_preserves_nn_module_stack_for_get_attr(self):
1826        class M(torch.nn.Module):
1827            def __init__(self) -> None:
1828                super().__init__()
1829                self.weight = torch.nn.Parameter(torch.ones(1, 1))
1830
1831            def forward(self, x):
1832                return self.weight + x
1833
1834        tracer = torch.fx.Tracer()
1835        graph = tracer.trace(M())
1836        gm = GraphModule(tracer.root, graph)
1837        for node in gm.graph.nodes:
1838            if node.op == 'get_attr':
1839                node.meta["nn_module_stack"] = "self"
1840                node.meta["stack_trace"] = "stack_trace"
1841                node.meta["source_fn_stack"] = "source_fn_stack"
1842        new_gm = Transformer(gm).transform()
1843        for node in new_gm.graph.nodes:
1844            if node.op == 'get_attr':
1845                self.assertEqual(node.meta["nn_module_stack"], "self")
1846                self.assertEqual(node.meta["stack_trace"], "stack_trace")
1847                self.assertEqual(node.meta["source_fn_stack"], "source_fn_stack")
1848
1849    def test_interpreter(self):
1850        class MyModule(torch.nn.Module):
1851            def __init__(self) -> None:
1852                super().__init__()
1853                self.param = torch.nn.Parameter(torch.rand(3, 4))
1854                self.linear = torch.nn.Linear(4, 5)
1855
1856            def forward(self, x):
1857                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
1858
1859        m = MyModule()
1860        gm = torch.fx.symbolic_trace(m)
1861
1862        interpreter = Interpreter(gm)
1863        input = torch.randn(3, 4)
1864        self.assertEqual(interpreter.run(input), gm(input))
1865        self.assertEqual(interpreter.run(input), m(input))
1866
1867    def test_interpreter_other_graph(self):
1868        class MyModule(torch.nn.Module):
1869            def __init__(self) -> None:
1870                super().__init__()
1871                self.param = torch.nn.Parameter(torch.rand(3, 4))
1872                self.linear = torch.nn.Linear(4, 5)
1873
1874            def forward(self, x):
1875                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
1876
1877        m = MyModule()
1878        gm = torch.fx.symbolic_trace(m)
1879
1880        interpreter = Interpreter(gm, graph=gm.graph)
1881        input = torch.randn(3, 4)
1882        self.assertEqual(interpreter.run(input), gm(input))
1883        self.assertEqual(interpreter.run(input), m(input))
1884
1885    def test_interpreter_run_node_override(self):
1886        class MyModule(torch.nn.Module):
1887            def __init__(self) -> None:
1888                super().__init__()
1889                self.param = torch.nn.Parameter(torch.rand(3, 4))
1890                self.linear = torch.nn.Linear(4, 5)
1891
1892            def forward(self, x):
1893                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
1894
1895        m = MyModule()
1896        gm = torch.fx.symbolic_trace(m)
1897
1898        class RunNodeInterpreter(Interpreter):
1899            def __init__(self, module):
1900                super().__init__(module)
1901
1902            def run_node(self, n : Node) -> Any:
1903                result = super().run_node(n)
1904                n.cached_value = result
1905                return result
1906
1907        input = torch.randn(3, 4)
1908        RunNodeInterpreter(gm).run(input)
1909        for node in gm.graph.nodes:
1910            assert hasattr(node, 'cached_value')
1911
1912    def test_interpreter_onthefly_swap(self):
1913
1914        def fn(x):
1915            return torch.sigmoid(x).neg()
1916
1917        gm = torch.fx.symbolic_trace(fn)
1918
1919        class NegSigmSwapInterpreter(Interpreter):
1920            def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
1921                if target == torch.sigmoid:
1922                    return torch.neg(*args, **kwargs)
1923                return super().call_function(n)  # noqa: F821
1924
1925            def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
1926                if target == 'neg':
1927                    call_self, *args_tail = args
1928                    return call_self.sigmoid(*args_tail, **kwargs)
1929                return super().call_method(n)  # noqa: F821
1930
1931        input = torch.randn(3, 4)
1932        result = NegSigmSwapInterpreter(gm).run(input)
1933        self.assertEqual(result, torch.neg(input).sigmoid())
1934
1935    def test_interpreter_partial_eval(self):
1936        class MyModule(torch.nn.Module):
1937            def __init__(self) -> None:
1938                super().__init__()
1939                self.param = torch.nn.Parameter(torch.rand(3, 4))
1940                self.linear = torch.nn.Linear(4, 5)
1941
1942            def forward(self, x):
1943                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
1944
1945        gm = torch.fx.symbolic_trace(MyModule())
1946        interp = Interpreter(gm)
1947        env = {}
1948        for node in gm.graph.nodes:
1949            if node.op == 'call_module' and node.target == 'linear':
1950                env[node] = torch.arange(0, 12, 1).reshape(3, 4) - 6.0
1951                break
1952        assert len(env) == 1
1953        x = torch.randn(3, 4)
1954        result = interp.run(x, initial_env=env)
1955        self.assertEqual(result, (torch.arange(0, 12, 1).reshape(3, 4) - 6.0).clamp(0.0, 1.0))
1956
1957    def test_interpreter_star_args(self):
1958        def with_star_args(x, *args):
1959            return x + args[0]
1960
1961        gm = torch.fx.symbolic_trace(with_star_args)
1962        interp = Interpreter(gm)
1963        result = interp.run(torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4))
1964        self.assertEqual(result, torch.ones(3, 4) * 2.0)
1965
1966    @skipIfNoTorchVision
1967    def test_interpreter_noop_resnet18(self):
1968        rn18 = torchvision_models.resnet18()
1969        transformed = torch.fx.Transformer(symbolic_trace(rn18)).transform()
1970        inp = torch.randn(5, 3, 224, 224)
1971        self.assertEqual(transformed(inp), rn18(inp))
1972
1973    @skipIfNoTorchVision
1974    def test_interpreter_gc_values(self):
1975        rn18 = torchvision_models.resnet18()
1976        interp = Interpreter(symbolic_trace(rn18))
1977        inp = torch.rand(5, 3, 224, 224)
1978        out = interp.run(inp)
1979        env_key_names = {n.name for n in interp.env.keys()}
1980        self.assertEqual(env_key_names, {'output'})
1981
1982    def test_interpreter_default_args(self):
1983        class Model(torch.nn.Module):
1984            def forward(self, x, y=3.14159):
1985                return x + y
1986
1987        model = Model()
1988        gm = torch.fx.symbolic_trace(model)
1989
1990        interp = Interpreter(gm)
1991        x = torch.randn(5, 3)
1992        out = interp.run(x)
1993        torch.testing.assert_close(out, x + 3.14159)
1994
1995    def test_interpreter_not_enough_args(self):
1996        class Model(torch.nn.Module):
1997            def forward(self, x, y):
1998                return x + y
1999
2000        model = Model()
2001        gm = torch.fx.symbolic_trace(model)
2002
2003        interp = Interpreter(gm)
2004        x = torch.randn(5, 3)
2005        with self.assertRaisesRegex(RuntimeError,
2006                                    'Expected positional argument for parameter y, but one was not passed in'):
2007            out = interp.run(x)
2008
2009    def test_transformer_noop(self):
2010        class MyModule(torch.nn.Module):
2011            def __init__(self) -> None:
2012                super().__init__()
2013                self.param = torch.nn.Parameter(torch.rand(3, 4))
2014                self.linear = torch.nn.Linear(4, 5)
2015
2016            def forward(self, x):
2017                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
2018
2019        m = MyModule()
2020        gm = torch.fx.symbolic_trace(m)
2021
2022        new_gm = Transformer(gm).transform()
2023
2024        input = torch.randn(3, 4)
2025        self.assertEqual(new_gm(input), gm(input))
2026
2027    def test_transformer_op_swap(self):
2028
2029        def fn(x):
2030            return torch.sigmoid(x).neg()
2031
2032        gm = torch.fx.symbolic_trace(fn)
2033
2034        class NegSigmSwapXformer(Transformer):
2035            def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
2036                if target == torch.sigmoid:
2037                    return torch.neg(*args, **kwargs)
2038                return super().call_function(n)  # noqa: F821
2039
2040            def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
2041                if target == 'neg':
2042                    call_self, *args_tail = args
2043                    return call_self.sigmoid(*args_tail, **kwargs)
2044                return super().call_method(n)  # noqa: F821
2045
2046        transformed = NegSigmSwapXformer(gm).transform()
2047        input = torch.randn(3, 4)
2048        self.assertEqual(transformed(input), torch.neg(input).sigmoid())
2049
2050    def test_transformer_multi_outputs(self):
2051        class MyModule(torch.nn.Module):
2052            def __init__(self) -> None:
2053                super().__init__()
2054                self.param = torch.nn.Parameter(torch.rand(3, 4))
2055                self.linear = torch.nn.Linear(4, 5)
2056
2057            def forward(self, x):
2058                x = x + self.param
2059                out = self.linear(x)
2060                return x, out
2061
2062        m = MyModule()
2063        gm = torch.fx.symbolic_trace(m)
2064
2065        new_gm = Transformer(gm).transform()
2066
2067        input = torch.randn(3, 4)
2068        self.assertEqual(new_gm(input), gm(input))
2069
2070    def test_fn_type_annotations(self):
2071        class Foo(torch.nn.Module):
2072            def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]:
2073                return {'a': p.x + p.y + z + i}
2074
2075        foo_scripted = torch.jit.script(Foo())
2076        foo_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)
2077
2078        fxed = symbolic_trace(Foo())
2079        fxed_scripted = torch.jit.script(fxed)
2080        fxed_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)
2081
2082    def test_fn_type_annotation_empty(self):
2083        def forward(a : List[torch.Tensor]):
2084            return a[0]
2085        torch.jit.script(symbolic_trace(forward))
2086
2087    def test_wrapped_method(self):
2088        def wrap_with_relu(fn):
2089            @functools.wraps(fn)
2090            def wrapper(*args, **kwargs):
2091                return torch.relu(fn(*args, **kwargs))
2092            return wrapper
2093
2094        class Foo(torch.nn.Module):
2095            @wrap_with_relu
2096            def forward(self, x, w):
2097                return torch.matmul(x, w)
2098
2099        f = Foo()
2100        traced = symbolic_trace(f)
2101        x, w = torch.rand(3, 4), torch.rand(4, 4)
2102        self.assertTrue(any(n.target == torch.relu for n in traced.graph.nodes))
2103
2104    def test_empty_graph_codegen(self):
2105        graph = torch.fx.Graph()
2106        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2107        self.assertEqual(gm(), None)
2108
2109    def test_sequential(self):
2110        m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1))
2111        gm = torch.fx.symbolic_trace(m)
2112        gm_copy = copy.deepcopy(gm)
2113
2114    def test_ctx_mgr(self):
2115        @contextlib.contextmanager
2116        def do_nothing():
2117            yield
2118
2119        class M(torch.nn.Module):
2120            @do_nothing()
2121            def forward(self, x):
2122                return torch.relu(x)
2123
2124        m = M()
2125        self.checkGraphModule(m, (torch.rand(3, 4),))
2126
2127    def test_typename_print(self):
2128        graph : torch.fx.Graph = torch.fx.Graph()
2129        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2130        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,),
2131                                              type_expr=List[float])
2132        output : torch.fx.Node = graph.output(b)
2133
2134        self.assertTrue('typing.List[float]' in str(graph))
2135
2136    def test_layout(self):
2137        class M(torch.nn.Module):
2138            def forward(self, x):
2139                return torch.empty_like(x, layout=torch.strided, pin_memory=False).fill_(0)
2140
2141        traced = symbolic_trace(M())
2142        x = torch.rand(5, 9, 3, 4)
2143        self.assertEqual(traced(x), torch.zeros_like(x))
2144
2145    def test_ellipsis(self):
2146        class M(torch.nn.Module):
2147            def forward(self, x, y):
2148                return x + y[:, 1:10, ...]
2149
2150        traced = symbolic_trace(M())
2151        x, y = torch.rand(5, 9, 3, 4), torch.rand(5, 15, 3, 4)
2152        self.assertEqual(traced(x, y), x + y[:, 1:10, ...])
2153
2154    def test_inf_nan(self):
2155        class FooMod(torch.nn.Module):
2156            def forward(self, x):
2157                return x + float('inf'), x + float('-inf'), x + float('nan')
2158
2159        fm = FooMod()
2160        self.checkGraphModule(fm, (torch.rand(3, 4),))
2161
2162    def test_inf_nan_kwds(self):
2163        graph : torch.fx.Graph = torch.fx.Graph()
2164        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2165        b : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('inf')), {}, name='inf')
2166        c : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('nan')), {}, name='nan')
2167        graph.output((b, c))
2168
2169        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2170        x = torch.rand(3, 4)
2171        self.assertEqual(gm(x), (x + float('inf'), x + float('nan')))
2172
2173    def test_deepcopy_recursion_depth(self):
2174        depth = sys.getrecursionlimit() + 20
2175
2176        g = torch.fx.Graph()
2177        x = g.placeholder('x')
2178        for i in range(depth):
2179            x = g.call_function(torch.relu, (x,))
2180        g.output(x)
2181
2182        copied_graph = copy.deepcopy(g)
2183
2184        val_map = {}
2185        for orig_node, new_node in zip(g.nodes, copied_graph.nodes):
2186            val_map[orig_node] = new_node
2187
2188        for orig_node, new_node in zip(g.nodes, copied_graph.nodes):
2189            orig_users = set(orig_node.users.keys())
2190            orig_users_equiv = {val_map[u] for u in orig_users}
2191            new_users = set(new_node.users.keys())
2192            self.assertEqual(orig_users_equiv, new_users)
2193
2194    @skipIfNoTorchVision
2195    def test_replace_uses(self):
2196        rn18 = torchvision_models.resnet18()
2197
2198        class LowerReluTracer(torch.fx.Tracer):
2199            def is_leaf_module(self, m : torch.nn.Module, qualname : str):
2200                if isinstance(m, torch.nn.ReLU):
2201                    return False
2202                return super().is_leaf_module(m, qualname)
2203
2204        rn18_traced = GraphModule(rn18, LowerReluTracer().trace(rn18))
2205
2206        to_erase = []
2207        for node in rn18_traced.graph.nodes:
2208            if node.op == 'call_function' and node.target in [torch.relu, torch.nn.functional.relu]:
2209                kwargs = node.kwargs.copy()
2210                # Neg doesn't have in-place
2211                kwargs.pop('inplace')
2212                with rn18_traced.graph.inserting_before(node):
2213                    new_node = rn18_traced.graph.call_function(
2214                        the_function=torch.neg, args=node.args, kwargs=node.kwargs)
2215                node.replace_all_uses_with(replace_with=new_node)
2216                to_erase.append(node)
2217
2218        for node in to_erase:
2219            rn18_traced.graph.erase_node(node)
2220
2221    def test_replace_input(self):
2222        graph : torch.fx.Graph = torch.fx.Graph()
2223        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2224        y : torch.fx.Node = graph.create_node('placeholder', 'y')
2225        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
2226        output : torch.fx.Node = graph.output(b)
2227
2228        b.replace_input_with(x, y)
2229
2230        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2231
2232        input_x = torch.randn(33, 44)
2233        input_y = torch.randn(11, 22)
2234        self.assertEqual(gm(input_x, input_y), torch.relu(input_y))
2235
2236    def test_insertion_point(self):
2237        graph : torch.fx.Graph = torch.fx.Graph()
2238        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2239        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
2240        output : torch.fx.Node = graph.output(b)
2241
2242        with graph.inserting_before(b):
2243            neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
2244            _, *relu_args = b.args
2245            b.args = (neg, *relu_args)
2246
2247        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2248
2249        input = torch.randn(33, 44)
2250        self.assertEqual(gm(input), torch.relu(torch.neg(input)))
2251
2252    def test_update_args_api(self):
2253        graph : torch.fx.Graph = torch.fx.Graph()
2254        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2255        y : torch.fx.Node = graph.create_node('placeholder', 'y')
2256        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
2257        output : torch.fx.Node = graph.output(b)
2258
2259        orig_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2260        inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5)
2261        self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x))
2262
2263        b.update_arg(0, y)
2264        new_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2265        self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y))
2266
2267    def test_update_kwargs_api(self):
2268        graph : torch.fx.Graph = torch.fx.Graph()
2269        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2270        y : torch.fx.Node = graph.create_node('placeholder', 'y')
2271        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, kwargs={'input': x})
2272        output : torch.fx.Node = graph.output(b)
2273
2274        orig_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2275        inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5)
2276        self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x))
2277
2278        b.update_kwarg('input', y)
2279        new_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2280        self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y))
2281
2282    def test_immutable_list_pytree_ops(self):
2283        rand_tensor = torch.randn(5, 3)
2284        l = immutable_list([3, [rand_tensor, 42]])
2285
2286        flattened, spec = pytree.tree_flatten(l)
2287        assert flattened == [3, rand_tensor, 42]
2288
2289        unflattened = pytree.tree_unflatten(flattened, spec)
2290        assert unflattened == l
2291        assert isinstance(unflattened, immutable_list)
2292
2293    def test_immutable_dict_pytree_ops(self):
2294        rand_tensor = torch.randn(5, 3)
2295        d = immutable_dict({'a': 3, 'b': [rand_tensor, 42]})
2296
2297        flattened, spec = pytree.tree_flatten(d)
2298        assert flattened == [3, rand_tensor, 42]
2299
2300        unflattened = pytree.tree_unflatten(flattened, spec)
2301        assert unflattened == d
2302        assert isinstance(unflattened, immutable_dict)
2303
2304    def test_move_before(self):
2305        graph : torch.fx.Graph = torch.fx.Graph()
2306        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2307        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
2308        output : torch.fx.Node = graph.output(b)
2309
2310        neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
2311        _, *relu_args = b.args
2312        b.args = (neg, *relu_args)
2313        b.prepend(neg)
2314
2315        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2316
2317        input = torch.randn(33, 44)
2318        self.assertEqual(gm(input), torch.relu(torch.neg(input)))
2319
2320    def test_prepend_self(self):
2321        graph : torch.fx.Graph = torch.fx.Graph()
2322        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2323        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
2324        output : torch.fx.Node = graph.output(b)
2325
2326        b.prepend(b)
2327        x.append(b)
2328        self.assertEqual(len(graph.nodes), 3)
2329
2330    def test_erase_node_error(self):
2331        st = SimpleTest()
2332        traced = symbolic_trace(st)
2333
2334        for node in traced.graph.nodes:
2335            # Test deleting with uses both in another Node and at the output
2336            if node.target in [operator.add, torch.relu]:
2337                with self.assertRaisesRegex(RuntimeError, 'but it still had .* users in the graph'):
2338                    traced.graph.erase_node(node)
2339
2340    def test_copy_it(self):
2341        d = immutable_dict([(3, 4), (5, 6)])
2342        l = immutable_list([(3, 4), (5, 6)])
2343
2344        self.assertEqual(d, deepcopy(d))
2345        self.assertEqual(l, deepcopy(l))
2346
2347    def test_get_torch_func_signature(self):
2348        for key in dir(torch):
2349            obj = getattr(torch, key)
2350            if callable(obj):
2351                schemas = get_signature_for_torch_op(obj)
2352
2353    def test_find_uses(self):
2354        graph = torch.fx.Graph()
2355        x = torch.fx.Proxy(graph.placeholder('x'))
2356
2357        y = torch.relu(x)
2358        z = x + x
2359        u = torch.neg(x)
2360        graph.output((y + z + u).node)
2361        graph.lint()
2362
2363        users_of_x = x.node.users
2364        self.assertEqual(len(users_of_x), 3)
2365        expected_ops = {'relu', 'add', 'neg'}
2366        for use in users_of_x:
2367            assert any(use.name.startswith(prefix) for prefix in expected_ops)
2368
2369    def test_inline_graph(self):
2370        class InlineInto(torch.nn.Module):
2371            def forward(self, x):
2372                return torch.relu(x)
2373
2374        class ToInline(torch.nn.Module):
2375            def forward(self, x):
2376                return torch.neg(x)
2377
2378        inline_into = symbolic_trace(InlineInto())
2379        to_inline = symbolic_trace(ToInline())
2380
2381        combined_graph = torch.fx.Graph()
2382        output_node = combined_graph.graph_copy(inline_into.graph, {})
2383
2384        input_node = next(iter(to_inline.graph.nodes))
2385        assert input_node and input_node.op == 'placeholder'
2386
2387        val_map = {input_node : output_node}
2388        output = combined_graph.graph_copy(to_inline.graph, val_map)
2389        combined_graph.output(output)
2390
2391        combined_module = torch.fx.GraphModule(torch.nn.Module(), combined_graph)
2392
2393        input = torch.rand(3, 4)
2394        self.assertEqual(combined_module(input), input.relu().neg())
2395
2396    def test_multi_insert_point(self):
2397        graph = torch.fx.Graph()
2398        x = torch.fx.Proxy(graph.placeholder('x'))
2399        relu = torch.relu(x)
2400
2401        with graph.inserting_before(relu.node):
2402            y = torch.neg(x)
2403            z = torch.tanh(y)
2404
2405        graph.output((relu.node, z.node))
2406        graph.lint()
2407
2408        expected_ops = ['x', 'neg', 'tanh', 'relu']
2409        for node, expected in zip(graph.nodes, expected_ops):
2410            assert expected in node.name
2411
2412    def test_reassign_args_kwargs_uses(self):
2413        graph = torch.fx.Graph()
2414        x, y = Proxy(graph.placeholder('x')), Proxy(graph.placeholder('y'))
2415        z = x + y
2416        zed = z + z + z
2417        graph.output(zed.node)
2418        graph.lint()
2419
2420        # zed = z + z + z -> zed = z + z + x
2421        zed.node.args = (zed.node.args[0], x.node)
2422        self.assertEqual(list(x.node.users.keys()), [z.node, zed.node])
2423
2424        # z = x + y -> z = y + y
2425        z.node.args = (y.node, y.node)
2426        self.assertEqual(list(x.node.users.keys()), [zed.node])
2427
2428    def test_trace_function(self):
2429        def foo(x, y):
2430            return torch.relu(x) + y
2431
2432        x, y = torch.randn(3, 4), torch.randn(3, 4)
2433        self.checkGraphModule(foo, (x, y))
2434
2435    def test_trace_return_dataclass(self):
2436        """
2437        Test case for Module that return dataclass
2438        """
2439        from dataclasses import dataclass
2440
2441        @dataclass
2442        class MyOutput:
2443            foo: torch.Tensor
2444            bar: torch.Tensor
2445
2446        class ModuleReturnDataclass(torch.nn.Module):
2447            def forward(self, d : torch.Tensor):
2448                return MyOutput(foo=d + d, bar=d * 3)
2449
2450        module = ModuleReturnDataclass()
2451        traced_graph = symbolic_trace(module).graph
2452        print(traced_graph)
2453
2454        gm = GraphModule(module, traced_graph)
2455        x = torch.rand(1)
2456
2457        self.assertEqual(module(x), gm(x))
2458
2459    def test_trace_return_dataclass_nested(self):
2460        """
2461        Test case for Module that return dataclass
2462        """
2463        from dataclasses import dataclass
2464
2465        @dataclass
2466        class MyOutput:
2467            foo: torch.Tensor
2468            bar: torch.Tensor
2469
2470        class ModuleReturnDataclass(torch.nn.Module):
2471            def forward(self, d : torch.Tensor):
2472                return MyOutput(foo=d + d, bar=d * 3)
2473
2474        class CallsModule(torch.nn.Module):
2475            def __init__(self) -> None:
2476                super().__init__()
2477                self.m = ModuleReturnDataclass()
2478
2479            def forward(self, x):
2480                tmp = self.m(x)
2481                return MyOutput(foo=tmp.foo, bar=tmp.bar)
2482
2483        module = CallsModule()
2484        traced_graph = symbolic_trace(module).graph
2485        print(traced_graph)
2486
2487        gm = GraphModule(module, traced_graph)
2488        x = torch.rand(1)
2489
2490        self.assertEqual(module(x), gm(x))
2491
2492    def test_trace_return_namedtuple(self):
2493        """
2494        Test case for Module that return namedtuple
2495        """
2496        class MyOutput(NamedTuple):
2497            foo: torch.Tensor
2498            bar: torch.Tensor
2499
2500        class ModuleReturnNamedTuple(torch.nn.Module):
2501            def forward(self, d : torch.Tensor):
2502                return MyOutput(foo=d, bar=d)
2503
2504        module = ModuleReturnNamedTuple()
2505
2506        traced_graph = symbolic_trace(module).graph
2507        print(traced_graph)
2508
2509        gm = GraphModule(module, traced_graph)
2510        x = torch.rand(1)
2511
2512        self.assertEqual(module(x), gm(x))
2513
2514    def test_trace_dict_int_keys(self):
2515        class ModWithDictArg(torch.nn.Module):
2516            def forward(self, d : Dict[int, torch.Tensor]):
2517                return d[42]
2518
2519        class CallsModWithDict(torch.nn.Module):
2520            def __init__(self) -> None:
2521                super().__init__()
2522                self.m = ModWithDictArg()
2523
2524            def forward(self, x):
2525                return self.m({42: x})
2526
2527        class MyTracer(torch.fx.Tracer):
2528            def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
2529                return isinstance(m, ModWithDictArg)
2530
2531        traced_graph = MyTracer().trace(CallsModWithDict())
2532
2533    def test_trace_dict_proxy_keys(self):
2534        class ModWithDictArg(torch.nn.Module):
2535            def forward(self, d : Dict[torch.Tensor, torch.Tensor]):
2536                return d[42]
2537
2538        class CallsModWithDict(torch.nn.Module):
2539            def __init__(self) -> None:
2540                super().__init__()
2541                self.m = ModWithDictArg()
2542
2543            def forward(self, x):
2544                return self.m({x: x})
2545
2546        class MyTracer(torch.fx.Tracer):
2547            def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
2548                return isinstance(m, ModWithDictArg)
2549
2550        with self.assertRaisesRegex(RuntimeError, 'cannot contain a Node'):
2551            traced_graph = MyTracer().trace(CallsModWithDict())
2552
2553    def test_module_deepcopy_edit_nodes(self):
2554        class Foo(torch.nn.Module):
2555            def forward(self, x):
2556                return torch.relu(x)
2557
2558        traced1 = symbolic_trace(Foo())
2559        copied = copy.deepcopy(traced1)
2560
2561        for node in copied.graph.nodes:
2562            if node.target == torch.relu:
2563                node.target = torch.neg
2564
2565        copied.recompile()
2566        traced1.recompile()
2567
2568        x = torch.randn(15, 15)
2569        torch.testing.assert_close(traced1(x), torch.relu(x))
2570        torch.testing.assert_close(copied(x), torch.neg(x))
2571
2572    def test_direct_param_use(self):
2573        class TransposeTest(torch.nn.Module):
2574            def __init__(self) -> None:
2575                super().__init__()
2576                self.b = torch.nn.Parameter(torch.rand(4, 3))
2577
2578            def forward(self, x):
2579                return self.b
2580
2581        class Foo(torch.nn.Module):
2582            def __init__(self) -> None:
2583                super().__init__()
2584                self.a = TransposeTest()
2585
2586            def forward(self, x):
2587                return self.a.b, self.a.b.t(), self.a.b.view(12)
2588
2589        traced = torch.fx.symbolic_trace(Foo())
2590        assert all('constant' not in node.target for node in traced.graph.nodes)
2591
2592    def test_single_default_arg(self):
2593        class M(torch.nn.Module):
2594            def forward(self, y=1):
2595                return y
2596
2597        m = M()
2598        self.checkGraphModule(m, ())
2599        self.checkGraphModule(m, (3,))
2600
2601    def test_multiple_default_args(self):
2602        class M(torch.nn.Module):
2603            def forward(self, y=1, z=2):
2604                return y + z
2605
2606        m = M()
2607        self.checkGraphModule(m, ())
2608        self.checkGraphModule(m, (3,))
2609        self.checkGraphModule(m, (3, 4))
2610
2611    def test_regular_and_default_args(self):
2612        class M(torch.nn.Module):
2613            def forward(self, x, y=1):
2614                return x + y
2615
2616        m = M()
2617        self.checkGraphModule(m, (2,))
2618        self.checkGraphModule(m, (2, 3))
2619
2620    def test_string_literal_return(self):
2621        class M(torch.nn.Module):
2622            def forward(self):
2623                return "foo"
2624
2625        m = M()
2626        self.checkGraphModule(m, ())
2627
2628    def test_namedtuple_return_qualname(self):
2629        class NamedTupReturn(torch.nn.Module):
2630            def forward(self, x):
2631                return MyNamedTup(x, x)
2632
2633        traced = symbolic_trace(NamedTupReturn())
2634        input = torch.rand(3, 4)
2635        self.assertEqual(traced(input), MyNamedTup(input, input))
2636
2637    def test_update_args_kwargs_yells_at_you(self):
2638        symtraced = symbolic_trace(SimpleTest())
2639        node = next(iter(symtraced.graph.nodes))
2640        with self.assertRaisesRegex(AttributeError, '__update_args_kwargs'):
2641            node.__update_args_kwargs((), {})
2642
2643    def test_torchbind_class_attribute_in_fx(self):
2644        if IS_FBCODE or IS_WINDOWS or IS_MACOS:
2645            self.skipTest("torch.classes._TorchScriptTesting._StackString is registered, skipping")
2646
2647        class FooBar1234(torch.nn.Module):
2648            def __init__(self) -> None:
2649                super().__init__()
2650                self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])
2651
2652            def forward(self):
2653                return self.f.top()
2654
2655        m = FooBar1234()
2656        self.checkGraphModule(m, ())
2657
2658    def test_torchbind_class_attribute_in_fx_tensor_arg(self):
2659        if IS_FBCODE or IS_WINDOWS or IS_MACOS:
2660            self.skipTest("torch.classes._TorchScriptTesting._ReLUClass is registered, skipping")
2661
2662        class FooBar2341(torch.nn.Module):
2663            def __init__(self) -> None:
2664                super().__init__()
2665                self.f = torch.classes._TorchScriptTesting._ReLUClass()
2666
2667            def forward(self, x):
2668                return self.f.run(x)
2669
2670        m = FooBar2341()
2671
2672        traced = symbolic_trace(m)
2673        input = torch.randn(3, 4)
2674        self.assertEqual(traced(input), m(input))
2675
2676        self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes))
2677
2678    def test_script_method_trace(self):
2679        class Scripted(torch.nn.Module):
2680            def forward(self, x):
2681                return torch.relu(x)
2682
2683        class Holder(torch.nn.Module):
2684            def __init__(self) -> None:
2685                super().__init__()
2686                self.s = torch.jit.script(Scripted())
2687
2688            def forward(self, x):
2689                return self.s(x)
2690
2691        h = Holder()
2692        traced = symbolic_trace(h)
2693        input = torch.randn(3, 4)
2694        self.assertEqual(traced(input), h(input))
2695
2696        self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes))
2697
2698    def test_namedtuple_return_trace(self):
2699        class NamedTupReturn(torch.nn.Module):
2700            def forward(self, x):
2701                return Pair(x, x)
2702
2703        traced = symbolic_trace(NamedTupReturn())
2704        input = torch.rand(3, 4)
2705        self.assertEqual(traced(input), Pair(input, input))
2706
2707    def test_named_tuple_inlined(self):
2708        class NamedTupMod(torch.nn.Module):
2709            def forward(self, inp):
2710                return wrapped_named_tup(Pair(inp, 1.2), p2=Pair(3.4, inp))
2711
2712        m = NamedTupMod()
2713        input = torch.rand(3, 4)
2714        ref = m(input)
2715        traced = symbolic_trace(m)
2716
2717        res = traced(input)
2718        self.assertEqual(ref, res)
2719
2720        # Check Pair NamedTuple works when inlined into the function call.
2721        ph = call_func = None
2722        for node in traced.graph.nodes:
2723            if node.op == "placeholder":
2724                ph = node
2725            elif node.op == "call_function" and node.target == wrapped_named_tup:
2726                node.update_arg(0, Pair(ph, 1.2))
2727                node.update_kwarg("p2", Pair(3.4, ph))
2728                call_func = node
2729                break
2730        self.assertTrue(call_func is not None)
2731        self.assertTrue(isinstance(call_func.args[0], Pair))
2732        self.assertTrue(isinstance(call_func.kwargs["p2"], Pair))
2733        self.assertEqual(_format_arg(call_func.args[0]), "Pair(x=%inp, y=1.2)")
2734        self.assertEqual(_format_arg(call_func.kwargs["p2"]), "Pair(x=3.4, y=%inp)")
2735
2736        traced.graph.eliminate_dead_code()
2737        traced.recompile()
2738        res = traced(input)
2739        self.assertEqual(ref, res)
2740
2741    def test_return_type_exists(self):
2742        class ReturnTypeModule(torch.nn.Module):
2743            def other(self, x: List[str]) -> List[str]:
2744                return x
2745
2746            def forward(self, x: List[str]) -> List[str]:
2747                return self.other(x)
2748
2749        traced = symbolic_trace(ReturnTypeModule())
2750        self.assertIn("-> typing_List[str]", traced._code)
2751        scripted = torch.jit.script(traced)
2752        self.assertIn("-> List[str]", scripted.code)
2753
2754    def getitem_inner(self):
2755        class GetItemBase(torch.nn.Module):
2756            def __init__(self) -> None:
2757                super().__init__()
2758                self.pe = torch.nn.Buffer(torch.randn(8, 8))
2759
2760        class GetItem1(GetItemBase):
2761            def forward(self, x):
2762                return self.pe[:, :x.size(0)]
2763
2764        class GetItem2(GetItemBase):
2765            def forward(self, x):
2766                return self.pe[x.size(0)]
2767
2768        class GetItem3(GetItemBase):
2769            def forward(self, x):
2770                return self.pe[4]  # fx creates `self._tensor_constant0` here
2771
2772        self.checkGraphModule(GetItem1(), [torch.zeros(4)])
2773        self.checkGraphModule(GetItem2(), [torch.zeros(4)])
2774        self.checkGraphModule(GetItem3(), [torch.zeros(4)])
2775
2776    @unittest.skipUnless(os.environ.get("FX_PATCH_GETITEM") == "1",
2777                         "Will be checked in test_getitem_subproc")
2778    def test_getitem(self):
2779        self.getitem_inner()
2780
2781    def test_getitem_subproc(self):
2782        # need to run this test in a subproc to work around:
2783        #   https://github.com/pytorch/pytorch/issues/50710
2784        proc = Process(target=run_getitem_target)
2785        proc.start()
2786        proc.join()
2787        self.assertEqual(proc.exitcode, 0)
2788
2789    def test_user_friendly_call_provenance_with_function(self):
2790        def fn(x):
2791            return wrapper_fn(x)
2792
2793        traced = torch.fx.symbolic_trace(fn)
2794
2795        with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is "
2796                                    "being compiled since it was called"
2797                                    " from 'fn.forward'"):
2798            scripted = torch.jit.script(traced)
2799
2800    def test_user_friendly_call_provenance_with_module(self):
2801        class M(torch.nn.Module):
2802            def forward(self, x):
2803                return wrapper_fn(x)
2804
2805        traced = torch.fx.symbolic_trace(M())
2806
2807        with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is "
2808                                    "being compiled since it was called"
2809                                    " from 'M.forward'"):
2810            scripted = torch.jit.script(traced)
2811
2812    def test_snake_case(self):
2813        class M(torch.nn.Module):
2814            def __init__(self) -> None:
2815                super().__init__()
2816                self.activations = torch.nn.ModuleDict([
2817                    ["snake_case", torch.nn.ReLU()],
2818                    ["PascalCase", torch.nn.LeakyReLU()],
2819                    ["ALL_CAPS", torch.nn.PReLU()]
2820                ])
2821
2822            def forward(self, x):
2823                a = self.activations["snake_case"](x)
2824                b = self.activations["PascalCase"](x)
2825                c = self.activations["ALL_CAPS"](x)
2826                return a, b, c
2827
2828        traced = symbolic_trace(M())
2829
2830        check = [
2831            ("activations_snake_case", "activations.snake_case"),
2832            ("activations_pascal_case", "activations.PascalCase"),
2833            ("activations_all_caps", "activations.ALL_CAPS")
2834        ]
2835
2836        i = 0
2837        for node in traced.graph.nodes:
2838            if node.op == "placeholder" or node.op == "output":
2839                continue
2840            name = check[i][0]
2841            target = check[i][1]
2842            self.assertEqual(name, node.name)
2843            self.assertEqual(target, node.target)
2844            i += 1
2845        self.assertEqual(i, 3)
2846
2847    def test_no_mutation(self):
2848        from torch.fx.immutable_collections import immutable_list
2849        x = immutable_list([3, 4])
2850        with self.assertRaisesRegex(NotImplementedError, "new_args"):
2851            x[0] = 4
2852
2853    def test_partial_trace(self):
2854        class Foo(torch.nn.Module):
2855            def forward(self, x, y):
2856                if y:
2857                    return 2 * x
2858                else:
2859                    return x
2860        mod = Foo()
2861        mod_true = symbolic_trace(mod, concrete_args={'y': True})
2862        mod_false = symbolic_trace(mod, concrete_args={'y': False})
2863        self.assertEqual(mod_true(3, True), 6)
2864        print(mod_true.code)
2865        assert any(i.target == torch._assert for i in mod_true.graph.nodes)
2866        with self.assertRaises(AssertionError):
2867            mod_true(3, False)
2868        self.assertEqual(mod_false(3, False), 3)
2869        with self.assertRaises(AssertionError):
2870            mod_false(3, True)
2871
2872        def f_higher(a, f):
2873            return f(a)
2874
2875        nf = symbolic_trace(f_higher, concrete_args={'f': lambda x: x * 2})
2876        self.assertEqual(nf(3, lambda x: x * 2), 6)
2877
2878    def test_custom_traceback_raised_when_exception_source_is_graphmodule(self):
2879        class M(torch.nn.Module):
2880            def __init__(self) -> None:
2881                super().__init__()
2882                self.W = torch.nn.Parameter(torch.randn(5))
2883
2884            def forward(self, x):
2885                return torch.dot(self.W, x)
2886
2887        traced = torch.fx.symbolic_trace(M())
2888
2889        out = [n for n in traced.graph.nodes if n.op == "output"][-1]
2890        with traced.graph.inserting_before(out):
2891            relu_out = traced.graph.call_method(method_name='relu',
2892                                                args=(out.args[0],))
2893        out.args = (relu_out,)
2894
2895        traced.recompile()
2896
2897        with self.capture_stderr() as captured:
2898            with self.assertRaises(TypeError):
2899                traced(5)
2900
2901        self.assertRegex(captured[0],
2902                         r"Call using an FX-traced Module, line .* of the "
2903                         r"traced Module's generated forward function:")
2904
2905    def test_custom_traceback_not_raised_when_exception_source_is_submodule(self):
2906        class M(torch.nn.Module):
2907            def __init__(self) -> None:
2908                super().__init__()
2909                self.linear = torch.nn.Linear(3, 4)
2910
2911            def forward(self, x):
2912                return self.linear(x)
2913
2914        traced = torch.fx.symbolic_trace(M())
2915
2916        # Do not change this to `capture_stderr` or another context
2917        # manager without ensuring that the output is as expected
2918        try:
2919            traced(torch.rand(5, 5))
2920        except RuntimeError:
2921            captured = traceback.format_exc()
2922
2923        self.assertNotRegex(captured,
2924                            r"Call using an FX-traced Module, line .* of the "
2925                            r"traced Module's generated forward function:")
2926
2927    def test_graph_module_replicate_for_dp(self):
2928        class Foo(torch.nn.Module):
2929            def forward(self, x):
2930                return torch.relu(x)
2931
2932        gm = torch.fx.symbolic_trace(Foo())
2933
2934        x = torch.randn(5, 3)
2935        out = gm(x)
2936
2937        replica = gm._replicate_for_data_parallel()
2938        out_replica = replica(x)
2939
2940        torch.testing.assert_close(out_replica, out)
2941
2942    def test_ast_rewriter_rewrites_assert(self):
2943        class M(torch.nn.Module):
2944            def forward(self, x: torch.Tensor, y: int, z: int):
2945                assert y == z
2946                return torch.add(x, x)
2947
2948        ast_rewriter = RewritingTracer()
2949        graph = ast_rewriter.trace(M())
2950        traced = GraphModule(ast_rewriter.root, graph, "gm")
2951
2952        traced.graph.lint()
2953
2954    def test_ast_rewriter_rewrites_assert_with_message(self):
2955        class M(torch.nn.Module):
2956            def forward(self, x: torch.Tensor, y: int, z: int):
2957                assert y == z, "msg"
2958                return torch.add(x, x)
2959
2960        ast_rewriter = RewritingTracer()
2961        graph = ast_rewriter.trace(M())
2962        traced = GraphModule(ast_rewriter.root, graph, "gm")
2963
2964        traced.graph.lint()
2965
2966    def test_throw_out_variant(self):
2967        def foo(x):
2968            y = torch.rand_like(x)
2969            torch.sigmoid(x, out=y)
2970            return y
2971
2972        class MyTracer(torch.fx.Tracer):
2973            check_mutable_operations = True
2974
2975        tracer = MyTracer()
2976        with self.assertRaisesRegex(RuntimeError, 'mutable operation aten::sigmoid.out'):
2977            traced_graph = tracer.trace(foo)
2978
2979    def test_ast_rewriter_reassigns_submodules(self):
2980        class M(torch.nn.Module):
2981            def __init__(self) -> None:
2982                super().__init__()
2983                self.bn = torch.nn.BatchNorm2d(100)
2984
2985            def forward(self, x: torch.Tensor):
2986                return torch.add(x, x)
2987
2988        ast_rewriter = RewritingTracer()
2989        graph = ast_rewriter.trace(M())
2990        traced = GraphModule(ast_rewriter.root, graph, "gm")
2991
2992        traced.graph.lint()
2993
2994    def test_ast_rewriter_wrap(self):
2995        self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5))
2996
2997        def to_trace(y):
2998            return (
2999                a_lifted_leaf((4, y), 3)
3000                + a_lifted_leaf((3, 4), 5)
3001                + a_lifted_leaf((y, y), y)
3002            )
3003
3004        ast_rewriter = RewritingTracer()
3005        graph = ast_rewriter.trace(to_trace)
3006        traced = GraphModule(ast_rewriter.root, graph, "gm")
3007
3008        self.assertIn("a_lifted_leaf", traced.code)
3009        self.assertEqual(27, traced(2))
3010        self.assertIs(a_lifted_leaf, real_a_lifed_leaf)
3011
3012    def test_ast_rewriter_wrap_fn_directly(self):
3013        self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))
3014
3015        def to_trace(y):
3016            return (
3017                a_lifted_leaf2((4, y), 3)
3018                + a_lifted_leaf2((3, 4), 5)
3019                + a_lifted_leaf2((y, y), y)
3020            )
3021
3022        ast_rewriter = RewritingTracer()
3023        graph = ast_rewriter.trace(to_trace)
3024        traced = GraphModule(ast_rewriter.root, graph, "gm")
3025
3026        self.assertIn("a_lifted_leaf2", traced.code)
3027        self.assertEqual(27, traced(2))
3028        self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2)
3029
3030    def test_profiler_ranges_side_effect(self):
3031        g = torch.fx.Graph()
3032        handle = g.call_function(torch.ops.profiler._record_function_enter_new, ('test_range',))
3033        g.call_function(torch.ops.profiler._record_function_exit, (handle,))
3034        g.output(None)
3035
3036        found_targets = {}
3037        for node in g.nodes:
3038            if node.op == 'call_function':
3039                found_targets.setdefault(node.target)
3040        self.assertEqual(
3041            list(found_targets.keys()),
3042            [torch.ops.profiler._record_function_enter_new, torch.ops.profiler._record_function_exit]
3043        )
3044
3045        g.eliminate_dead_code()
3046        found_targets = {}
3047        for node in g.nodes:
3048            if node.op == 'call_function':
3049                found_targets.setdefault(node.target)
3050        self.assertEqual(
3051            list(found_targets.keys()),
3052            [torch.ops.profiler._record_function_enter_new, torch.ops.profiler._record_function_exit]
3053        )
3054
3055    def test_ast_rewriter_wrapped_via_decorator(self):
3056        class F(torch.nn.Module):
3057            def forward(self, x):
3058                return wrapped_via_decorator(x)
3059
3060        ast_rewriter = RewritingTracer()
3061        graph = ast_rewriter.trace(F())
3062        traced = GraphModule(ast_rewriter.root, graph, "gm")
3063
3064        self.assertIn("wrapped_via_decorator", traced.code)
3065        self.assertEqual(traced(0), 1)
3066        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
3067        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
3068
3069    def test_ast_rewriter_wrapped_via_decorator_and_transformed(self):
3070        self.assertEqual(wrapped_via_decorator(0), 1)
3071
3072        def to_trace(y):
3073            return wrapped_via_decorator(y)
3074
3075        ast_rewriter = RewritingTracer()
3076        graph = ast_rewriter.trace(to_trace)
3077        traced = GraphModule(ast_rewriter.root, graph, "gm")
3078
3079        self.assertIn("wrapped_via_decorator", traced.code)
3080        self.assertEqual(traced(0), 1)
3081        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
3082        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
3083
3084        transformed = torch.fx.Transformer(traced).transform()
3085        self.assertIn("wrapped_via_decorator", transformed.code)
3086        self.assertEqual(transformed(0), 1)
3087        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
3088        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
3089
3090    def test_ast_rewriter_wrap_with_submodule(self):
3091        class M(torch.nn.Module):
3092            def __init__(self) -> None:
3093                super().__init__()
3094                self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
3095
3096            def forward(self, x: torch.Tensor):
3097                return wrapped_with_submodule(x, self.batchnorm1d)
3098
3099        ast_rewriter = RewritingTracer()
3100        graph = ast_rewriter.trace(M())
3101        traced = GraphModule(ast_rewriter.root, graph, "gm")
3102
3103        self.assertIn("wrapped_with_submodule", traced.code)
3104
3105        input = torch.rand(3, 2)
3106        ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
3107        self.assertEqual(ref_batchnorm1d(input), traced(input))
3108
    def test_submodule_manipulation_API(self):
        # Walks through the GraphModule submodule/attribute API on a three-level
        # module hierarchy (A -> net_b: B -> net_c: C): add_submodule,
        # delete_submodule, get_submodule, get_parameter, get_buffer, the
        # warning behavior of Graph.call_module / Graph.get_attr, and
        # delete_all_unused_submodules. The steps mutate `a` in place and
        # depend on each other's order.
        class C(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(16, 33, 3, stride=2)
                self.param = torch.nn.Parameter(torch.rand(2, 3))

            def forward(self, x):
                return self.conv(torch.cat([self.param, x]))

        class B(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(100, 200)
                self.buf = torch.nn.Buffer(torch.randn(2, 3))
                self.net_c = C()

            def forward(self, x):
                return self.linear(torch.cat([self.buf, self.net_c(x)]))

        class A(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.net_b = B()
                self.param = torch.nn.Parameter(torch.rand(2, 3))

            def forward(self, x):
                return self.net_b(x) + self.param

        a = symbolic_trace(A())

        # Register a new submodule under an existing nested path.
        a.add_submodule("net_b.net_c.dropout", torch.nn.Dropout(p=0.2))

        # Last node whose target is the conv submodule.
        conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"][-1]
        with a.graph.inserting_before(conv):
            # The target was just added, so no warning is expected here.
            with warnings.catch_warnings(record=True) as w:
                dropout = a.graph.call_module(module_name="net_b.net_c.dropout",
                                              args=conv.args)
                self.assertEqual(len(w), 0)

        # Redirect every use of conv to dropout, then drop the conv node.
        conv.replace_all_uses_with(dropout)
        a.graph.erase_node(conv)
        a.recompile()

        # Helpers: does `path` name a module / parameter / buffer of `gm`?
        # The parameter and buffer variants also require the state_dict entry.
        def module_exists(gm: GraphModule, path: str) -> bool:
            return any(path == name for name, _ in gm.named_modules())

        def parameter_exists(gm: GraphModule, path: str) -> bool:
            return (any(path == name for name, _ in gm.named_parameters())
                    and any(path == name for name in gm.state_dict().keys()))

        def buffer_exists(gm: GraphModule, path: str) -> bool:
            return (any(path == name for name, _ in gm.named_buffers())
                    and any(path == name for name in gm.state_dict().keys()))

        # Test that we added the "dropout" submodule
        self.assertTrue(module_exists(a, "net_b.net_c.dropout"))

        # Test `get_submodule` with an added submodule
        self.assertIsNotNone(a.get_submodule("net_b.net_c.dropout"))

        # Test that the "conv" submodule is still there
        self.assertTrue(module_exists(a, "net_b.net_c.conv"))

        # Test `get_submodule` with an original module
        self.assertIsNotNone(a.get_submodule("net_b.net_c.conv"))

        # Test that the "conv" node is NOT still there
        conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"]
        self.assertEqual(conv, [])

        a.delete_submodule("net_b.net_c.conv")

        # Test that the "conv" submodule is now gone
        self.assertFalse(module_exists(a, "net_b.net_c.conv"))

        # Test `get_submodule` with a deleted submodule
        # (get_submodule is expected to raise, in which case the surrounding
        # assertIsNone never gets to run)
        with self.assertRaisesRegex(AttributeError, "has no attribute "
                                    "`conv`"):
            self.assertIsNone(a.get_submodule("net_b.net_c.conv"))

        # Test `get_attr` warnings
        cat = [n for n in a.graph.nodes if n.target == torch.cat][-1]

        with a.graph.inserting_before(cat):

            # "net_b.net_c.param" is defined on C, so no warning is expected.
            with warnings.catch_warnings(record=True) as w:
                param = a.graph.get_attr(qualified_name="net_b.net_c.param")
                self.assertEqual(len(w), 0)

            # B defines no `param` attribute, so this lookup is expected to warn.
            # The node is erased again right away.
            with self.assertWarnsRegex(UserWarning, "Attempted to "
                                       "insert a get_attr Node with no "
                                       "underlying reference in the "
                                       "owning GraphModule"):
                bad_param = a.graph.get_attr(qualified_name="net_b.param")
                a.graph.erase_node(bad_param)

        # Append the new get_attr node to the cat node's arguments so it is used.
        cat.args = (*cat.args, param)

        a.recompile()

        a.graph.lint()

        # Test `get_parameter`
        a.get_parameter("net_b.net_c.param")
        with self.assertRaisesRegex(AttributeError, "is not an "
                                    "nn.Parameter"):
            a.get_parameter("net_b.buf")
        with self.assertRaisesRegex(AttributeError, "has no attribute "
                                    "`param`"):
            a.get_parameter("net_b.param")

        # Test `get_buffer`
        a.get_buffer("net_b.buf")
        with self.assertRaisesRegex(AttributeError, "is not a "
                                    "buffer"):
            a.get_buffer("net_b.net_c.param")
        with self.assertRaisesRegex(AttributeError, "has no attribute "
                                    "`buf`"):
            a.get_buffer("net_b.net_c.buf")

        # Test non-nested attributes
        a.get_submodule("")
        a.get_parameter("param")

        # Insert some unused submodules
        a.add_submodule("net_b.embedding", torch.nn.Embedding(10, 3))
        a.add_submodule("net_b.net_c.embedding", torch.nn.Embedding(10, 3))
        a.add_submodule("net_b.net_c.rnn", torch.nn.RNN(10, 20, 2))
        a.add_submodule("batch_norm_2d", torch.nn.BatchNorm2d(100))

        # Garbage collection
        a.delete_all_unused_submodules()

        # Test that all the unused submodules are gone
        self.assertFalse(module_exists(a, "net_b.embedding"))
        self.assertFalse(module_exists(a, "net_b.net_c.embedding"))
        self.assertFalse(module_exists(a, "net_b.net_c.rnn"))
        self.assertFalse(module_exists(a, "batch_norm_2d"))

        # Test that we didn't delete any unused Parameters or buffers
        self.assertTrue(parameter_exists(a, "net_b.net_c.param"))
        self.assertTrue(buffer_exists(a, "net_b.buf"))

        a.graph.lint()
3254
3255    def test_delete_unused_submodules_leaf(self):
3256        class SubModule(torch.nn.Module):
3257            def __init__(self) -> None:
3258                super().__init__()
3259                self.linear = torch.nn.Linear(10, 10)
3260                self.relu = torch.nn.ReLU()
3261
3262            def forward(self, x):
3263                x = self.linear(x)
3264                x = self.relu(x)
3265                return x
3266
3267        class Model(torch.nn.Module):
3268            def __init__(self) -> None:
3269                super().__init__()
3270                self.submod = SubModule()
3271
3272            def forward(self, x):
3273                x = self.submod(x)
3274                return x
3275
3276        model = Model()
3277
3278        class MyCustomTracer(torch.fx.Tracer):
3279            def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
3280                return module_qualified_name == "submod"
3281
3282        inputs = torch.randn(1, 10)
3283        traced_graph = MyCustomTracer().trace(model)
3284        gm2 = torch.fx.GraphModule(model, traced_graph)
3285        gm2.delete_all_unused_submodules()
3286        torch.testing.assert_close(gm2(inputs), model(inputs))
3287
3288    def test_fx_stateless(self):
3289        class MockModule(torch.nn.Module):
3290            def __init__(self) -> None:
3291                super().__init__()
3292                self.l1 = torch.nn.Linear(1, 1)
3293                self.buffer = torch.nn.Buffer(torch.ones(1))
3294
3295            def forward(self, x):
3296                return self.l1(x) + self.buffer
3297
3298        module = MockModule()
3299        x = torch.rand((1, 1))
3300        weight = torch.tensor([[1.0]], requires_grad=True)
3301        bias = torch.tensor([0.0], requires_grad=True)
3302        buffer = torch.tensor([0.0])
3303        parameters = {'l1.weight': weight,
3304                      'l1.bias': bias,
3305                      'buffer': buffer}
3306        fx_module = torch.fx.symbolic_trace(module)
3307        res = torch.func.functional_call(fx_module, parameters, x)
3308        res.backward()
3309        self.assertIsNotNone(weight.grad)
3310        self.assertIsNotNone(bias.grad)
3311        self.assertIsNone(buffer.grad)
3312        # Gradient was not calculated for the module stated and buffers
3313        self.assertIsNone(module.l1.weight.grad)
3314        self.assertIsNone(module.l1.bias.grad)
3315        self.assertIsNone(module.buffer.grad)
3316
3317    def test_tracing_graphmodules_as_leaf_submodules(self):
3318        class A(torch.nn.Module):
3319            def forward(self, t):
3320                return t + t
3321
3322        class B(torch.nn.Module):
3323            def __init__(self) -> None:
3324                super(type(self), self).__init__()
3325                self.calling = False
3326                self.called = False
3327
3328            def forward(self, t):
3329                if self.calling:
3330                    return t - t
3331                else:
3332                    return t + t
3333
3334            def __call__(self, *args):
3335                self.called = True
3336                self.calling = True
3337                return super(type(self), self).__call__(*args)
3338                self.calling = False
3339
3340        class M(torch.nn.Module):
3341            def __init__(self, a, b):
3342                super().__init__()
3343                self.a = a
3344                self.b = b
3345
3346            def forward(self, t):
3347                x = self.a(t)
3348                y = self.b(t)
3349                return x + y
3350
3351        class LeafTracer(Tracer):
3352            def is_leaf_module(self, module, name):
3353                return True
3354
3355        class LeafTracerNotB(Tracer):
3356            def is_leaf_module(self, module, name):
3357                return False if "b" in name else True
3358
3359        # Recompile calls added "for fun", since they
3360        # chain __call__ wrappers.
3361
3362        #
3363        # Test: B as a regular, non-leaf module
3364        #
3365        a = symbolic_trace(A())
3366        a.recompile()
3367        m = M(a, B())
3368        graph = LeafTracerNotB().trace(m)
3369        gm = GraphModule(m, graph)
3370        gm.recompile()
3371
3372        # Test graphmodule/submodule a is not inlined.
3373        self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
3374        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
3375        self.assertTrue(len(match) == 1)
3376
3377        # Test submodule b is not treated as leaf.
3378        self.assertFalse(hasattr(gm, "b"))
3379
3380        # Test assert custom __call__ on submodule b was honored.
3381        match = [
3382            n
3383            for n in gm.graph.nodes
3384            if n.op == "call_function" and n.target == operator.sub
3385        ]
3386        self.assertTrue(len(match) == 1)
3387
3388        #
3389        # Test: B as a regular, leaf module
3390        # symbolic_trace should only patch torch.nn.Module.__call__,
3391        # which means B.__call__ should still execute
3392        #
3393        a = symbolic_trace(A())
3394        a.recompile()
3395        b = B()
3396        m = M(a, b)
3397        graph = LeafTracer().trace(m)
3398        gm = GraphModule(m, graph)
3399        gm.recompile()
3400
3401        # Test graphmodule/submodule a is not inlined.
3402        self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
3403        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
3404        self.assertTrue(len(match) == 1)
3405
3406        # Test submodule b is leaf:
3407        self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module))
3408        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"]
3409        self.assertTrue(len(match) == 1)
3410
3411        # Test b.__call__ was run
3412        self.assertTrue(b.called)
3413        self.assertTrue(gm.get_submodule("b").called)
3414
3415        #
3416        # Test: B as GraphModule leaf
3417        # __call__ not honored since symbolic_trace directly invokes forward()
3418        #
3419        a = symbolic_trace(A())
3420        a.recompile()
3421        b = symbolic_trace(B())
3422        b.recompile()
3423        m = M(a, b)
3424        graph = LeafTracer().trace(m)
3425        gm = GraphModule(m, graph)
3426        gm.recompile()
3427
3428        self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
3429        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
3430        self.assertTrue(len(match) == 1)
3431
3432        self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module))
3433        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"]
3434        self.assertTrue(len(match) == 1)
3435
3436    def _test_graph_module_init_buffer_param_copied(self, use_dict_init: bool):
3437        class MyModule(torch.nn.Module):
3438            def __init__(self) -> None:
3439                super().__init__()
3440                self.my_buff = torch.nn.Buffer(torch.rand(3, 4))
3441                self.register_parameter(
3442                    "my_param", torch.nn.Parameter(torch.rand(3, 4))
3443                )
3444
3445            def forward(self, x):
3446                return x + self.my_buff + self.my_param
3447
3448        mod = MyModule()
3449        mod_traced = symbolic_trace(mod)
3450
3451        # Create new GraphModule based on original, either w/ dict or root module.
3452        orig_buff = mod_traced.get_buffer("my_buff")
3453        orig_param = mod_traced.get_parameter("my_param")
3454        mod_traced_new = GraphModule(
3455            {"my_buff": orig_buff, "my_param": orig_param} if use_dict_init else mod,
3456            mod_traced.graph,
3457        )
3458
3459        # Check that both my_buff and my_param are found and the same.
3460        try:
3461            new_buff = mod_traced_new.get_buffer("my_buff")
3462        except Exception:
3463            self.fail("Did not find my_buff")
3464        self.assertEqual(orig_buff, new_buff)
3465
3466        try:
3467            new_param = mod_traced_new.get_parameter("my_param")
3468        except Exception:
3469            self.fail("Did not find my_param")
3470        self.assertEqual(orig_param, new_param)
3471
3472        x = torch.rand(3, 4)
3473        orig_out = mod_traced(x)
3474        submodules_out = mod_traced_new(x)
3475
3476        self.assertEqual(orig_out, submodules_out)
3477
3478    def test_graph_module_init_buffer_param_copied_dict_init(self):
3479        self._test_graph_module_init_buffer_param_copied(use_dict_init=True)
3480
3481    def test_graph_module_init_buffer_param_copied_mod_init(self):
3482        self._test_graph_module_init_buffer_param_copied(use_dict_init=False)
3483
3484    def test_annotations_with_no_forward_references(self):
3485        class A:
3486            def __call__(self, x: torch.Tensor):
3487                return torch.add(x, x)
3488
3489        class M(torch.nn.Module):
3490            def forward(self, x: torch.Tensor, a: A) -> torch.Tensor:
3491                return a(x)
3492
3493        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
3494
3495    def test_annotations_with_forward_references(self):
3496        class A:
3497            def __call__(self, x: torch.Tensor):
3498                return torch.add(x, x)
3499
3500        class M(torch.nn.Module):
3501            def forward(self, x: 'torch.Tensor', a: 'A') -> 'torch.Tensor':
3502                return a(x)
3503
3504        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
3505
3506    def test_annotations_with_non_torch_reference_and_no_internal_forward_references(self):
3507        class A:
3508            def __call__(self, x: torch.Tensor):
3509                return torch.add(x, x)
3510
3511        class M(torch.nn.Module):
3512            def forward(self, x: List[torch.Tensor], a: A) -> torch.Tensor:
3513                return a(x[0])
3514
3515        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
3516
3517    def test_annotations_with_non_torch_reference_and_internal_forward_references(self):
3518        class A:
3519            def __call__(self, x: torch.Tensor):
3520                return torch.add(x, x)
3521
3522        class M(torch.nn.Module):
3523            def forward(self, x: List['torch.Tensor'], a: A) -> 'torch.Tensor':
3524                return a(x)[0]
3525
3526        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
3527
3528    @unittest.skipIf(sys.version_info < (3, 7), "`__future__` feature "
3529                     "`annotations` is not defined in Python <3.7")
3530    def test_annotation_with_future(self):
3531        try:
3532            import fx.test_future    # noqa: F401
3533        finally:
3534            del sys.modules["__future__"]
3535
3536    @unittest.skipIf(sys.version_info > (3, 11), "Does not work in 3.11")
3537    def test_annotations_empty_tuple(self):
3538        class Foo(torch.nn.Module):
3539            def forward(self, x: Tuple[()], y: Tuple[str, Tuple[()]]):
3540                return "foo"
3541
3542        traced = torch.fx.symbolic_trace(Foo())
3543
3544        x = ()
3545        y = ("bar", ())
3546
3547        traced(x, y)
3548
3549        FileCheck().check("_Tuple[()]")   \
3550                   .check("typing_Tuple[str,typing_Tuple[()]]") \
3551                   .run(traced.code)
3552
3553        scripted = torch.jit.script(traced)
3554
3555        scripted(x, y)
3556
3557        FileCheck().check("Tuple[()]")   \
3558            .check("Tuple[str, Tuple[()]]")    \
3559            .run(scripted.code)
3560
3561    @unittest.skipIf(IS_WINDOWS, "Python Windows bug? https://bugs.python.org/issue45108")
3562    @unittest.skipIf(sys.version_info >= (3, 10), "Does not work on Python-3.10")
3563    def test_assert(self):
3564        def f(x):
3565            assert x > 1
3566            return x + 1
3567        try:
3568            torch.fx.proxy.TracerBase.trace_asserts = True
3569            traced = symbolic_trace(f)
3570        finally:
3571            torch.fx.proxy.TracerBase.trace_asserts = False
3572
3573        self.assertEqual(f(2), traced(2))
3574        with self.assertRaises(AssertionError):
3575            traced(0)
3576
    def test_pytree(self):
        """Trace functions over pytree-structured inputs via ``concrete_args``.

        Each case pairs a function with an input structure whose leaves are
        placeholder markers; ``verify_pytree`` below traces, re-traces and
        pickles the result, comparing against the eager output each time.
        """
        # Used to test that you can use your own placeholder class
        class PHTest(PHBase):
            pass

        def f_sum(x):
            return sum(x)

        def f_sum_dict(x):
            out = 0
            for v in x.values():
                out += v
            return out

        def f_dict_list_map(x):
            new_dict = {}
            for k, v in x.items():
                new_dict[k] = [i + 1 for i in v]
            return new_dict

        def f_dict_add(x):
            return x['a'] + sum(x['z'])

        def f_namedtuple_add(x):
            return x.x + x.y

        # Register `Foo` with both pytree registries; this mutates
        # process-wide registration state.
        pytree.register_pytree_node(
            Foo,
            lambda x: ([x.a, x.b], None),
            lambda x, _: Foo(x[0], x[1]),
        )
        fx_pytree.register_pytree_flatten_spec(Foo, lambda x, _: [x.a, x.b])

        def f_custom(x):
            return x.a + x.b

        def f_custom_dict(x):
            return f_sum_dict(x.a) + x.b

        # Only referenced by the commented-out case in `tests` below.
        def f_return_custom(x):
            return Foo(x.b, x.a)

        # (function, input structure) pairs fed to verify_pytree.
        tests = [
            (f_sum, [PH, PH, PH]),
            (f_sum, []),
            (f_sum, [PHTest(), PHTest(), PHTest()]),
            (f_sum_dict, {'a': PH, 'b': PH, 'c': PH}),
            (f_dict_list_map, {'a': (PH, PH), 'b': [PH], 'c': []}),
            (f_dict_list_map, {5: (PH, PH, PH)}),
            (f_dict_add, {'a': PH, 'z': (PH, PH, PH)}),
            (f_dict_add, {'a': PH, 'z': []}),
            (f_custom, Foo(PH, PH)),
            (f_custom, Foo(PH, 3)),
            (f_custom_dict, Foo({'a': PH, 'b': PH}, PH)),
            # (f_return_custom, Foo(PH, PH)), # Don't currently support output pytrees
            (f_namedtuple_add, Point(PH, PH)),
        ]

        def verify_pytree(f, inp):
            # Swap every PHBase leaf for a random tensor to get a concrete input.
            val = pytree.tree_map(lambda x: torch.randn(3) if isinstance(x, PHBase) else x, inp)
            num_flat_args = len(pytree.tree_leaves(inp))
            orig_out = f(val)
            nf = symbolic_trace(f, concrete_args={'x': inp})
            self.assertEqual(nf(val), orig_out)

            # Same graph with the default CodeGen: inputs/outputs have to be
            # converted by hand through the original graph's process_* hooks.
            bare_fx = GraphModule({}, copy.deepcopy(nf.graph))
            bare_fx.graph.set_codegen(CodeGen())
            bare_fx.recompile()
            self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(val))), orig_out)

            assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
            assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == num_flat_args

            # Re-tracing without concrete_args yields a single placeholder and
            # no flatten-spec call in the generated code.
            nf = symbolic_trace(nf)
            self.assertEqual(nf(val), orig_out)
            assert "tree_flatten_spec" not in nf.code
            assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == 1

            # Re-tracing with concrete_args flattens the input again.
            nf = symbolic_trace(nf, concrete_args={'x': inp})
            self.assertEqual(nf(val), orig_out)
            assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
            assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == num_flat_args

            # The traced module must survive a pickle round trip.
            pickled = pickle.dumps(nf)
            nf = pickle.loads(pickled)
            self.assertEqual(nf(val), orig_out)

        for f, inp in tests:
            verify_pytree(f, inp)
3666
3667    def test_pytree_concrete(self):
3668        def f(b, a):
3669            if b:
3670                return a['a']
3671            else:
3672                return a['z']
3673
3674        inp = {'a': {'a': PH, 'z': PH}, 'b': True}
3675        nf = symbolic_trace(f, concrete_args=inp)
3676        val = pytree.tree_map(lambda x: torch.randn(3) if x == PH else x, inp)
3677        self.assertEqual(nf(**val), f(**val))
3678
3679        nf = symbolic_trace(nf)
3680        self.assertEqual(nf(**val), f(**val))
3681
3682    def test_metadata_on_ph(self):
3683        def f_sum(a: int, b: int) -> int:
3684            return a + b
3685
3686        # Due to unflattening of dict, the batch argument
3687        # will be split into two separate nodes with the names
3688        # "batch_1" and "batch_2", referring to the keys
3689        # "f1" and "f2" respectively in the dict.
3690        def f_dict(a: Dict[str, str]) -> bool:
3691            return a["f1"] == a["f2"]
3692
3693        def verify_metadata(gm: GraphModule, arg_names: List[str], metadata: List[str]):
3694            for node in gm.graph.nodes:
3695                if node.op == "placeholder":
3696                    self.assertTrue(node.name in arg_names)
3697                    self.assertTrue(node.ph_key in metadata)
3698
3699        verify_metadata(
3700            gm=symbolic_trace(
3701                f_sum,
3702                concrete_args={"a": PHWithMeta(ph_key="a"), "b": PHWithMeta(ph_key="b")}
3703            ),
3704            arg_names=["a_1", "b_1"],
3705            metadata=["a", "b"]
3706        )
3707        verify_metadata(
3708            gm=symbolic_trace(
3709                f_dict,
3710                concrete_args={"a": {"f1": PHWithMeta(ph_key="f1"), "f2": PHWithMeta(ph_key="f2")}}
3711            ),
3712            arg_names=["a_1", "a_2"],
3713            metadata=["f1", "f2"]
3714        )
3715
3716        # Ensures that tags on nodes are NOT overwritten by PH attributes with same attr name (tag)
3717        class TaggingTracer(Tracer):
3718            def create_node(self, kind : str, target : Union[str, Callable],
3719                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
3720                            type_expr : Optional[Any] = None) -> Node:
3721                n = super().create_node(kind, target, args, kwargs, name)
3722                n.tag = "foo"
3723                return n
3724
3725        class PHWithTag(PHBase):
3726            def __init__(self, tag: str):
3727                super().__init__()
3728
3729                self.tag = tag
3730
3731        g = TaggingTracer().trace(f_sum, concrete_args={"a": PHWithTag(tag="bar"), "b": PHWithTag(tag="bar")})
3732        for n in g.nodes:
3733            self.assertTrue(hasattr(n, "tag"))
3734            # Ensure that tag is still "foo" and not "bar" (from PHWithTag)
3735            self.assertEqual(n.tag, "foo")
3736
3737    def test_custom_codegen(self):
3738        class ListCodeGen(CodeGen):
3739            def gen_fn_def(self, free_vars, maybe_return_annotation):
3740                lst_unpack = f"""
3741def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
3742    {', '.join(free_vars)} = args_list"""
3743                return lst_unpack
3744
3745            def additional_globals(self):
3746                return [('List', typing.List)]
3747
3748            def process_inputs(self, *inputs):
3749                assert len(inputs) == 1
3750                return inputs[0]
3751
3752        def f(a, b):
3753            return a + b
3754
3755        nf = symbolic_trace(f)
3756        vals = [torch.randn(3), torch.randn(3)]
3757        self.assertEqual(nf(*vals), f(*vals))
3758
3759        nf.graph.set_codegen(ListCodeGen())
3760        nf.recompile()
3761
3762        bare_fx = GraphModule({}, copy.deepcopy(nf.graph))
3763        bare_fx.graph.set_codegen(CodeGen())
3764        bare_fx.recompile()
3765
3766        self.assertEqual(nf(vals), f(*vals))
3767        self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(vals))), f(*vals))
3768
3769        ts_f = torch.jit.script(nf)
3770        self.assertEqual(nf(vals), ts_f(vals))
3771
3772    def test_custom_codegen_with_transformer(self):
3773        class ListCodeGen(CodeGen):
3774            def gen_fn_def(self, free_vars, maybe_return_annotation):
3775                lst_unpack = f"""
3776def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
3777    {', '.join(free_vars)} = args_list"""
3778                return lst_unpack
3779
3780            def additional_globals(self):
3781                return [('List', typing.List)]
3782
3783            def process_inputs(self, *inputs):
3784                assert len(inputs) == 1
3785                return inputs[0]
3786
3787        def f(a, b):
3788            return a + b
3789
3790        nf = symbolic_trace(f)
3791        vals = [torch.randn(3), torch.randn(3)]
3792        self.assertEqual(nf(*vals), f(*vals))
3793
3794        nf.graph.set_codegen(ListCodeGen())
3795        nf.recompile()
3796        self.assertEqual(nf(vals), f(*vals))
3797
3798        transformed_gm = Transformer(nf).transform()
3799        self.assertEqual(nf(vals), transformed_gm(vals))
3800
3801    def test_interpreter_with_codegen(self):
3802        class ListCodeGen(CodeGen):
3803            def gen_fn_def(self, free_vars, maybe_return_annotation):
3804                lst_unpack = f"""
3805def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
3806    {', '.join(free_vars)} = args_list"""
3807                return lst_unpack
3808
3809            def additional_globals(self):
3810                return [('List', typing.List)]
3811
3812            def process_inputs(self, *inputs):
3813                assert len(inputs) == 1
3814                return inputs[0]
3815
3816            def generate_output(self, output_args):
3817                return f'return list({repr(output_args)})'
3818
3819            def process_outputs(self, outputs):
3820                return list(outputs)
3821
3822        def f(a, b):
3823            a = a + b
3824            b = a + b
3825            return a, b
3826
3827        nf = symbolic_trace(f)
3828        vals = [torch.randn(3), torch.randn(3)]
3829        nf.graph.set_codegen(ListCodeGen())
3830        nf.recompile()
3831        self.assertEqual(Interpreter(nf).run(vals), nf(vals))
3832
3833    def test_imul_code_print(self):
3834        graph = torch.fx.Graph()
3835        a = graph.placeholder("a")
3836        b = graph.placeholder("b")
3837        graph.call_function(operator.imul, (a, b), {})
3838        graph.output(a)
3839        gm = torch.fx.GraphModule({}, graph)
3840        gm.recompile()
3841        self.assertEqual(gm(2, 3), 6)
3842        self.assertIn("a *= b", gm.code)
3843
3844    def test_deepcopy_tracer(self):
3845        def fn(x, y):
3846            return (x + y).relu().sin()
3847
3848        tracer = Tracer()
3849        tracer_before = copy.deepcopy(tracer)
3850        tracer.trace(fn)
3851        tracer_after = copy.deepcopy(tracer)
3852
3853        self.assertEqual(str(tracer.graph), str(tracer_after.graph))
3854        self.assertTrue(not hasattr(tracer_before, 'graph') or str(tracer.graph) != str(tracer_before.graph))
3855
3856    def test_deepcopy_graphmodule(self):
3857        m = symbolic_trace(SimpleTest())
3858        m.meta['hello'] = 'world'
3859        copy_m = copy.deepcopy(m)
3860        self.assertEqual(copy_m.meta['hello'], 'world')
3861
3862    def test_deepcopy_no_recursion(self):
3863        m = symbolic_trace(SimpleTest())
3864        m.meta['hello'] = m  # circular reference
3865        copy_m = copy.deepcopy(m)  # finishes
3866        self.assertEqual(id(copy_m), id(copy_m.meta['hello']))
3867
3868    def test_enum(self):
3869        from enum import Enum
3870
3871        class Foo(Enum):
3872            A = 1
3873            B = 2
3874
3875        def leaf_fn(arr, enum_val):
3876            # Use the raw enum.
3877            arr.append(enum_val)
3878            return arr[-1].value
3879
3880        def foo(x):
3881            # Pass the enum as argument.
3882            return leaf_fn(x, Foo.A)
3883
3884        traced = torch.fx.symbolic_trace(foo)
3885        self.assertEqual(foo([]), traced([]))
3886
3887    def test_insert_arg(self):
3888        m = symbolic_trace(SimpleTest())
3889        m.buf = torch.nn.Buffer(torch.tensor(0))
3890        output_node = next(iter(reversed(m.graph.nodes)))
3891        with m.graph.inserting_before(output_node):
3892            a = m.graph.get_attr("buf")
3893        r = len(output_node.args)
3894        output_node.insert_arg(0, a)
3895        self.assertEqual(len(output_node.args), r + 1)
3896        self.assertEqual(len(a.users), 1)
3897        self.assertIs(output_node.args[0], a)
3898        self.assertIs(next(iter(a.users.keys())), output_node)
3899        output_node.insert_arg(2, a)
3900        self.assertEqual(len(output_node.args), r + 2)
3901        self.assertEqual(len(a.users), 1)
3902        self.assertIs(output_node.args[2], a)
3903        self.assertIs(next(iter(a.users.keys())), output_node)
3904        m.graph.lint()
3905
3906    def test_delete_unused_values(self):
3907        from torch.fx.experimental.proxy_tensor import make_fx
3908
3909        # disable mutable checking temporarily
3910        orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
3911        torch.fx.proxy.TracerBase.check_mutable_operations = False
3912
3913        def fn(a, b, c, d):
3914            x = a + b
3915            y = c + d
3916            y.copy_(x)
3917            x = torch.relu(x)
3918            return x
3919
3920        a, b, c, d = (torch.randn(2, 4, requires_grad=False) for _ in range(4))
3921        fx_fn = make_fx(fn)(a, b, c, d)
3922        print(fx_fn)
3923
3924        fx_fn.graph.eliminate_dead_code()
3925        py_code = fx_fn.recompile()
3926        self.assertTrue("copy_ = torch.ops.aten.copy_.default" in py_code.src)
3927        self.assertTrue("copy_ = None" in py_code.src)
3928
3929        # recorver mutable checking flag
3930        torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag
3931
def run_getitem_target():
    """Run ``TestFX.getitem_inner`` with ``Tensor.__getitem__`` added to FX's wrapped-method list.

    The entry is appended to the module-level list and always popped again,
    even when the inner test raises.
    """
    from torch.fx._symbolic_trace import _wrapped_methods_to_patch as patch_targets
    patch_targets.append((torch.Tensor, "__getitem__"))
    try:
        TestFX().getitem_inner()
    finally:
        patch_targets.pop()
3939
3940
class TestOperatorSignatures(JitTestCase):
    """Checks that schemas from ``get_signature_for_torch_op`` accept OpInfo sample inputs."""

    def setUp(self):
        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        # NOTE(review): unlike TestFXAPIBackwardCompatibility, this class does
        # not call super().setUp()/tearDown() -- presumably deliberate; confirm
        # before adding one.
        tracer_base = torch.fx.proxy.TracerBase
        self.orig_tracer_mutable_flag = tracer_base.check_mutable_operations
        tracer_base.check_mutable_operations = True

    def tearDown(self):
        # Put back whatever value setUp found.
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag

    @onlyCPU
    @ops(op_db, allowed_dtypes=(torch.float,))
    def test_get_torch_func_signature_exhaustive(self, device, dtype, op):
        """Every sample input of a builtin op must bind to, and run with, at least one schema."""
        if not isinstance(op.op, types.BuiltinFunctionType):
            raise unittest.SkipTest("This path doesn't work on Python functions")
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
        schemas = get_signature_for_torch_op(op.op)
        if not schemas:
            raise RuntimeError('No Schemas Returned')
        for sample_input in sample_inputs_itr:
            # Try the overloads in order; a TypeError from binding or from the
            # call means "not this overload", so move on to the next one.
            matched = False
            for schema in schemas:
                try:
                    bound_args = schema.bind(sample_input.input, *sample_input.args, **sample_input.kwargs)
                    bound_args.apply_defaults()
                    op(*bound_args.args, **bound_args.kwargs)
                except TypeError:
                    continue
                matched = True
                break
            if not matched:
                raise RuntimeError(f'Did not match any schemas for op {op.name}!')
3973
3974
3975class TestFXAPIBackwardCompatibility(JitTestCase):
3976    def setUp(self):
3977        super().setUp()
3978        self.maxDiff = None
3979
3980        # Checking for mutable operations whil tracing is feature flagged
3981        # Enable it in testing but not by default
3982        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
3983        torch.fx.proxy.TracerBase.check_mutable_operations = True
3984
3985    def tearDown(self):
3986        super().tearDown()
3987        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag
3988
3989
3990    def _fn_to_stable_annotation_str(self, obj):
3991        """
3992        Unfortunately we have to serialize function signatures manually since
3993        serialization for `inspect.Signature` objects is not stable across
3994        python versions
3995        """
3996        fn_name = torch.typename(obj)
3997
3998        signature = inspect.signature(obj)
3999
4000        sig_str = f'{fn_name}{signature}'
4001
4002        arg_strs = []
4003        for k, v in signature.parameters.items():
4004            maybe_type_annotation = f': {self._annotation_type_to_stable_str(v.annotation, sig_str)}'\
4005                if v.annotation is not inspect.Signature.empty else ''
4006
4007            def default_val_str(val):
4008                if isinstance(val, (tuple, list)):
4009                    str_pieces = ['(' if isinstance(val, tuple) else '[']
4010                    str_pieces.append(', '.join(default_val_str(v) for v in val))
4011                    if isinstance(val, tuple) and len(str_pieces) == 2:
4012                        str_pieces.append(',')
4013                    str_pieces.append(')' if isinstance(val, tuple) else ']')
4014                    return ''.join(str_pieces)
4015
4016                # Need to fix up some default value strings.
4017                # First case: modules. Default module `repr` contains the FS path of the module.
4018                # Don't leak that
4019                if isinstance(val, types.ModuleType):
4020                    return f'<module {val.__name__}>'
4021
4022                # Second case: callables. Callables (such as lambdas) encode their address in
4023                # their string repr. Don't do that
4024                if callable(val):
4025                    return f'<function {val.__name__}>'
4026
4027                return str(val)
4028
4029            if v.default is not inspect.Signature.empty:
4030                default_val_str = default_val_str(v.default) if not isinstance(v.default, str) else f"'{v.default}'"
4031                maybe_default = f' = {default_val_str}'
4032            else:
4033                maybe_default = ''
4034            maybe_stars = ''
4035            if v.kind == inspect.Parameter.VAR_POSITIONAL:
4036                maybe_stars = '*'
4037            elif v.kind == inspect.Parameter.VAR_KEYWORD:
4038                maybe_stars = '**'
4039            arg_strs.append(f'{maybe_stars}{k}{maybe_type_annotation}{maybe_default}')
4040
4041        return_annot = f' -> {self._annotation_type_to_stable_str(signature.return_annotation, sig_str)}'\
4042            if signature.return_annotation is not inspect.Signature.empty else ''
4043
4044        return f'{fn_name}({", ".join(arg_strs)}){return_annot}'
4045
4046    def _annotation_type_to_stable_str(self, t, sig_str):
4047        if t is inspect.Signature.empty:
4048            return ''
4049
4050        # Forward ref
4051        if isinstance(t, str):
4052            return f"'{t}'"
4053        if hasattr(typing, 'ForwardRef') and isinstance(t, typing.ForwardRef):
4054            return t.__forward_arg__
4055        if hasattr(typing, '_ForwardRef') and isinstance(t, typing._ForwardRef):
4056            return t.__forward_arg__
4057
4058        trivial_mappings = {
4059            str : 'str',
4060            int : 'int',
4061            float: 'float',
4062            bool: 'bool',
4063            torch.dtype: 'torch.dtype',
4064            torch.Tensor: 'torch.Tensor',
4065            torch.device: 'torch.device',
4066            torch.memory_format: 'torch.memory_format',
4067            slice: 'slice',
4068            torch.nn.Module: 'torch.nn.modules.module.Module',
4069            torch.fx.Graph : 'torch.fx.graph.Graph',
4070            torch.fx.Node : 'torch.fx.node.Node',
4071            torch.fx.Proxy : 'torch.fx.proxy.Proxy',
4072            torch.fx.node.Target : 'torch.fx.node.Target',
4073            torch.fx.node.Argument : 'torch.fx.node.Argument',
4074            torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode',
4075            torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule',
4076            torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match',
4077            Ellipsis : '...',
4078            typing.Any: 'Any',
4079            type(None): 'NoneType',
4080            None: 'None',
4081            typing.Iterator: 'Iterator',
4082        }
4083
4084        mapping = trivial_mappings.get(t, None)
4085        if mapping:
4086            return mapping
4087
4088        # Handle types with contained types
4089        contained = getattr(t, '__args__', None) or []
4090
4091        # Callables contain a bare List for arguments
4092        contained = t if isinstance(t, list) else contained
4093
4094        # Python 3.8 puts type vars into __args__ for unbound types such as Dict
4095        if all(isinstance(ct, typing.TypeVar) for ct in contained):
4096            contained = []
4097
4098        contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str) for ct in contained]
4099        contained_type_str = f'[{", ".join(contained_type_annots)}]' if len(contained_type_annots) > 0 else ''
4100
4101
4102        origin = getattr(t, '__origin__', None)
4103        if origin is None:
4104            # Unbound types don't have `__origin__` in some Python versions, so fix that up here.
4105            origin = t if t in {typing.Tuple, typing.Union, typing.Dict, typing.List, typing.Type, typing.Callable} else origin
4106
4107        if origin in {tuple, typing.Tuple}:
4108            return f'Tuple{contained_type_str}'
4109        if origin in {typing.Union}:
4110            # Annoying hack to detect Optional
4111            if len(contained) == 2 and (contained[0] is type(None)) ^ (contained[1] is type(None)):
4112                not_none_param = contained[0] if contained[0] is not type(None) else contained[1]
4113                return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]'
4114            return f'Union{contained_type_str}'
4115        if origin in {dict, typing.Dict}:
4116            return f'Dict{contained_type_str}'
4117        if origin in {list, typing.List}:
4118            return f'List{contained_type_str}'
4119        if origin in {type, typing.Type}:
4120            return f'Type{contained_type_str}'
4121        if isinstance(t, typing.Callable):
4122            if len(contained) > 0 and contained[0] is not Ellipsis:
4123                return f'Callable[[{", ".join(contained_type_annots[:-1])}], {contained_type_annots[-1]}]'
4124            else:
4125                return f'Callable{contained_type_str}'
4126
4127        raise RuntimeError(f'Unrecognized type {t} used in BC-compatible type signature {sig_str}.'
4128                           f'Please add support for this type and confirm with the '
4129                           f'FX team that your signature change is valid.')
4130
4131
4132    def test_function_back_compat(self):
4133        """
4134        Test backward compatibility for function signatures with
4135        @compatibility(is_backward_compatible=True). Currently this checks for
4136        exact signature matches, which may lead to false positives. If this
4137        becomes too annoying, we can refine this check to actually parse out
4138        the saved schema strings and check if the change is truly backward-
4139        incompatible.
4140        """
4141        signature_strs = []
4142
4143        for obj in _BACK_COMPAT_OBJECTS:
4144            if not isinstance(obj, type):
4145                signature_strs.append(self._fn_to_stable_annotation_str(obj))
4146
4147        signature_strs.sort()
4148
4149        try:
4150            self.assertExpected('\n'.join(signature_strs) + '\n', 'fx_backcompat_function_signatures')
4151        except AssertionError as e:
4152            msg = f"{e}\n****** ERROR ******\nAn FX function that has been marked " \
4153                  f"as backwards-compatible has experienced a signature change. See the " \
4154                  f"above exception context for more information. If this change was " \
4155                  f"unintended, please revert it. If it was intended, check with the FX " \
4156                  f"team to ensure that the proper deprecation protocols have been followed " \
4157                  f"and subsequently --accept the change."
4158            raise AssertionError(msg)  # noqa: B904
4159
4160    def test_class_member_back_compat(self):
4161        """
4162        Test backward compatibility for members of classes with
4163        @compatibility(is_backward_compatible=True). Currently this checks for
4164        exact matches on the publicly visible members of the class.
4165        """
4166        class_method_strs = []
4167
4168        for obj in _BACK_COMPAT_OBJECTS:
4169            if isinstance(obj, type):
4170                public_members = [name for name in obj.__dict__ if not name.startswith('_')]
4171                class_method_strs.append(f'{torch.typename(obj)} {sorted(public_members)}')
4172
4173        class_method_strs.sort()
4174
4175        try:
4176            self.assertExpected('\n'.join(class_method_strs), 'fx_backcompat_class_members')
4177        except AssertionError as e:
4178            msg = f"{e}\n****** ERROR ******\nAn FX class that has been marked " \
4179                  f"as backwards-compatible has experienced change in its public members. See the " \
4180                  f"above exception context for more information. If this change was " \
4181                  f"unintended, please revert it. If it was intended, check with the FX " \
4182                  f"team to ensure that the proper deprecation protocols have been followed " \
4183                  f"and subsequently --accept the change."
4184            raise AssertionError(msg) from e
4185
4186    def test_public_api_surface(self):
4187        non_back_compat_objects = {}
4188
4189        def check_symbols_have_bc_designation(m, seen):
4190            if not m.__name__.startswith('torch.fx'):
4191                return
4192            if m.__name__.startswith('torch.fx.experimental'):
4193                return
4194            # It's really common for inner functions to point to random modules
4195            # - make sure we don't recurse into modules we've already checked.
4196            seen.add(m.__name__)
4197            for k, v in m.__dict__.items():
4198                if hasattr(v, '__name__') and v.__name__ in seen:
4199                    continue
4200                if v is m:
4201                    continue
4202                if k.startswith('_'):
4203                    continue
4204                if isinstance(v, types.ModuleType):
4205                    check_symbols_have_bc_designation(v, seen)
4206                elif isinstance(v, (type, types.FunctionType)):
4207                    if v not in _MARKED_WITH_COMPATIBILITY:
4208                        non_back_compat_objects.setdefault(v)
4209
4210        check_symbols_have_bc_designation(torch.fx, set())
4211        check_symbols_have_bc_designation(torch.fx.passes, set())
4212
4213        non_back_compat_strs = [torch.typename(obj) for obj in non_back_compat_objects.keys()]
4214        # Only want objects in torch.fx
4215        non_back_compat_strs = [
4216            s for s in non_back_compat_strs if s.startswith('torch.fx') and not s.startswith('torch.fx.experimental')]
4217        # Only want objects in public namespaces
4218        non_back_compat_strs = [
4219            s for s in non_back_compat_strs if all(not atom.startswith('_') for atom in s.split('.'))]
4220        non_back_compat_strs.sort()
4221
4222        if len(non_back_compat_strs) != 0:
4223            raise AssertionError(f"Public FX API(s) {non_back_compat_strs} introduced but not given a "
4224                                 f"backwards-compatibility classification! Please decorate these "
4225                                 f"API(s) with `@torch.fx._compatibility.compatibility` to specify "
4226                                 f"BC guarantees.")
4227
4228    def test_adding_side_effect_function(self):
4229        class TestModule(torch.nn.Module):
4230            def forward(self, x):
4231                side_effect_func(x)
4232                return x
4233
4234        gm = torch.fx.symbolic_trace(TestModule())
4235        self.assertEqual(len(gm.graph.nodes), 3)
4236        gm.graph.eliminate_dead_code()
4237        gm.recompile()
4238        self.assertEqual(len(gm.graph.nodes), 3)
4239        found = False
4240        for node in gm.graph.nodes:
4241            if node.op == 'call_function' and node.target == side_effect_func:
4242                found = True
4243        self.assertTrue(found)
4244
4245    def test_preserve_unused_attr_after_unpickle(self):
4246        gm = torch.fx.symbolic_trace(Add())
4247        gm.add_submodule("foo", Add())
4248        gm.dummy_buffer = torch.nn.Buffer(torch.empty(1))
4249        gm.register_parameter("dummy_parameter", torch.nn.Parameter(torch.empty(1)))
4250        b = io.BytesIO()
4251        torch.save(gm, b)
4252        b.seek(0)
4253        # weights_only=False as this loads a GraphModule
4254        # GLOBAL torch.fx.graph_module.reduce_graph_module was not an allowed global by default
4255        reload_gm = torch.load(b, weights_only=False)
4256        self.assertTrue(hasattr(reload_gm, "foo"))
4257        self.assertTrue(hasattr(reload_gm, "dummy_buffer"))
4258        self.assertTrue(hasattr(reload_gm, "dummy_parameter"))
4259
4260# This is failing on Python 3.12 : https://github.com/pytorch/pytorch/issues/119454
4261@unittest.skipIf(
4262    sys.version_info >= (3, 12), "Failing on python 3.12+"
4263)
4264class TestFunctionalTracing(JitTestCase):
4265    def setUp(self):
4266        super().setUp()
4267        # Checking for mutable operations whil tracing is feature flagged
4268        # Enable it in testing but not by default
4269        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
4270        torch.fx.proxy.TracerBase.check_mutable_operations = True
4271
4272    def tearDown(self):
4273        super().tearDown()
4274        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag
4275
4276    IGNORE_FUNCS = ("has_torch_function", "has_torch_function_unary",
4277                    "has_torch_function_variadic", "handle_torch_function",
4278                    "boolean_dispatch")
4279    TO_PATCH = {"has_torch_function": None,
4280                "has_torch_function_unary": None,
4281                "has_torch_function_variadic": None}
4282
4283    BUILT_IN_FUNC = (AssertionError, "")
4284    PROXY_ITERABLE = (TypeError, r"argument of type 'Proxy' is not iterable")
4285    PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated")
4286    LEN_ERROR = (RuntimeError, r"'len' is not supported in symbolic tracing by default")
4287    ARG_TYPE_MISMATCH = (TypeError, r", not Proxy$")
4288    CONTROL_FLOW = (TraceError, r"symbolically traced variables cannot be used as inputs to control flow")
4289    INTERPOLATE_ARGS_CONFLICT = (ValueError, r"only one of size or scale_factor should be defined")
4290    MUTABLE = (RuntimeError, r"Tried to trace mutable operation")
4291
4292    UNTRACEABLE_FUNCTIONALS = {
4293        "adaptive_avg_pool1d": BUILT_IN_FUNC,
4294        "avg_pool1d": BUILT_IN_FUNC,
4295        "avg_pool2d": BUILT_IN_FUNC,
4296        "avg_pool3d": BUILT_IN_FUNC,
4297        "bilinear": BUILT_IN_FUNC,
4298        "celu_": BUILT_IN_FUNC,
4299        "channel_shuffle": BUILT_IN_FUNC,
4300        "native_channel_shuffle": BUILT_IN_FUNC,
4301        "conv1d": BUILT_IN_FUNC,
4302        "conv2d": BUILT_IN_FUNC,
4303        "conv3d": BUILT_IN_FUNC,
4304        "conv_tbc": BUILT_IN_FUNC,
4305        "conv_transpose1d": BUILT_IN_FUNC,
4306        "conv_transpose2d": BUILT_IN_FUNC,
4307        "conv_transpose3d": BUILT_IN_FUNC,
4308        "cosine_similarity": BUILT_IN_FUNC,
4309        "elu_": BUILT_IN_FUNC,
4310        "gelu": BUILT_IN_FUNC,
4311        "hardshrink": BUILT_IN_FUNC,
4312        "hardtanh_": BUILT_IN_FUNC,
4313        "leaky_relu_": BUILT_IN_FUNC,
4314        "linear": BUILT_IN_FUNC,
4315        "logsigmoid": BUILT_IN_FUNC,
4316        "one_hot": BUILT_IN_FUNC,
4317        "pad": ARG_TYPE_MISMATCH,
4318        "pairwise_distance": BUILT_IN_FUNC,
4319        "pdist": BUILT_IN_FUNC,
4320        "pixel_shuffle": BUILT_IN_FUNC,
4321        "pixel_unshuffle": BUILT_IN_FUNC,
4322        "prelu": BUILT_IN_FUNC,
4323        "relu_": BUILT_IN_FUNC,
4324        "rrelu_": BUILT_IN_FUNC,
4325        "selu_": BUILT_IN_FUNC,
4326        "scaled_dot_product_attention": BUILT_IN_FUNC,
4327        "softplus": BUILT_IN_FUNC,
4328        "softshrink": BUILT_IN_FUNC,
4329        "threshold_": BUILT_IN_FUNC,
4330
4331        "adaptive_avg_pool2d": LEN_ERROR,
4332        "adaptive_avg_pool3d": LEN_ERROR,
4333        "adaptive_max_pool2d_with_indices": LEN_ERROR,
4334        "adaptive_max_pool3d_with_indices": LEN_ERROR,
4335        "instance_norm": CONTROL_FLOW,
4336
4337        "adaptive_max_pool1d": PROXY_ITERABLE,
4338        "adaptive_max_pool2d": PROXY_ITERABLE,
4339        "adaptive_max_pool3d": PROXY_ITERABLE,
4340        "fractional_max_pool2d": PROXY_ITERABLE,
4341        "fractional_max_pool3d": PROXY_ITERABLE,
4342        "max_pool1d": PROXY_ITERABLE,
4343        "max_pool2d": PROXY_ITERABLE,
4344        "max_pool3d": PROXY_ITERABLE,
4345
4346        "lp_pool2d": PROXY_ITERATED,
4347        "lp_pool3d": PROXY_ITERATED,
4348        "max_unpool1d": PROXY_ITERATED,
4349        "max_unpool2d": PROXY_ITERATED,
4350        "max_unpool3d": PROXY_ITERATED,
4351        "fold": PROXY_ITERATED,
4352        "unfold": PROXY_ITERATED,
4353
4354        "adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH,
4355        "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH,
4356        "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH,
4357        "layer_norm": ARG_TYPE_MISMATCH,
4358        "rms_norm": ARG_TYPE_MISMATCH,
4359        "lp_pool1d": ARG_TYPE_MISMATCH,
4360
4361        "affine_grid": CONTROL_FLOW,
4362        "alpha_dropout": CONTROL_FLOW,
4363        "batch_norm": CONTROL_FLOW,
4364        "binary_cross_entropy": CONTROL_FLOW,
4365        "binary_cross_entropy_with_logits": CONTROL_FLOW,
4366        "celu": CONTROL_FLOW,
4367        "cosine_embedding_loss": CONTROL_FLOW,
4368        "cross_entropy": CONTROL_FLOW,
4369        "ctc_loss": CONTROL_FLOW,
4370        "dropout": CONTROL_FLOW,
4371        "dropout1d": CONTROL_FLOW,
4372        "dropout2d": CONTROL_FLOW,
4373        "dropout3d": CONTROL_FLOW,
4374        "elu": CONTROL_FLOW,
4375        "embedding": CONTROL_FLOW,
4376        "embedding_bag": CONTROL_FLOW,
4377        "feature_alpha_dropout": CONTROL_FLOW,
4378        "gaussian_nll_loss": CONTROL_FLOW,
4379        "glu": CONTROL_FLOW,
4380        "grid_sample": CONTROL_FLOW,
4381        "group_norm": CONTROL_FLOW,
4382        "gumbel_softmax": CONTROL_FLOW,
4383        "hardsigmoid": CONTROL_FLOW,
4384        "hardswish": CONTROL_FLOW,
4385        "hardtanh": CONTROL_FLOW,
4386        "hinge_embedding_loss": CONTROL_FLOW,
4387        "huber_loss": CONTROL_FLOW,
4388        "interpolate": CONTROL_FLOW,
4389        "kl_div": CONTROL_FLOW,
4390        "l1_loss": CONTROL_FLOW,
4391        "leaky_relu": CONTROL_FLOW,
4392        "local_response_norm": CONTROL_FLOW,
4393        "margin_ranking_loss": CONTROL_FLOW,
4394        "max_pool1d_with_indices": ARG_TYPE_MISMATCH,
4395        "max_pool2d_with_indices": ARG_TYPE_MISMATCH,
4396        "max_pool3d_with_indices": ARG_TYPE_MISMATCH,
4397        "mse_loss": CONTROL_FLOW,
4398        "multi_head_attention_forward": CONTROL_FLOW,
4399        "multi_margin_loss": CONTROL_FLOW,
4400        "multilabel_margin_loss": CONTROL_FLOW,
4401        "multilabel_soft_margin_loss": CONTROL_FLOW,
4402        "nll_loss": CONTROL_FLOW,
4403        "poisson_nll_loss": CONTROL_FLOW,
4404        "relu": CONTROL_FLOW,
4405        "relu6": CONTROL_FLOW,
4406        "rrelu": CONTROL_FLOW,
4407        "selu": CONTROL_FLOW,
4408        "silu": CONTROL_FLOW,
4409        "mish": CONTROL_FLOW,
4410        "smooth_l1_loss": CONTROL_FLOW,
4411        "soft_margin_loss": CONTROL_FLOW,
4412        "threshold": CONTROL_FLOW,
4413        "triplet_margin_loss": CONTROL_FLOW,
4414        "triplet_margin_with_distance_loss": CONTROL_FLOW,
4415        "upsample": CONTROL_FLOW,
4416
4417        "upsample_bilinear": INTERPOLATE_ARGS_CONFLICT,
4418        "upsample_nearest": INTERPOLATE_ARGS_CONFLICT,
4419    }
4420
4421    # List of nn.functionals with Tensor inputs but not with type annotation
4422    FUNCTIONALS_WITHOUT_ANNOTATION = (
4423        "adaptive_max_pool1d",
4424        "adaptive_max_pool2d",
4425        "adaptive_max_pool3d",
4426        "fractional_max_pool2d",
4427        "fractional_max_pool3d",
4428        "max_pool1d",
4429        "max_pool2d",
4430        "max_pool3d",
4431        "gaussian_nll_loss",
4432        "upsample",
4433        "upsample_bilinear",
4434        "upsample_nearest",
4435    )
4436
4437    # Inconsistent behavior between Python 3.8 and other Python versions:
4438    # - Python 3.8+: Re-raise internal exception like `PROXY_ITERATED`
4439    # - Other Python: Raise `argument of type 'Proxy' is not iterable` due to the same
4440    #                 internal exception above
4441    # Use the following map to override the expected exception for Python 3.8
4442    UNTRACEABLE_FUNCTIONALS_PY38 = {
4443        "adaptive_max_pool1d": PROXY_ITERATED,
4444        "adaptive_max_pool2d": PROXY_ITERATED,
4445        "adaptive_max_pool3d": PROXY_ITERATED,
4446        "fractional_max_pool2d": PROXY_ITERATED,
4447        "fractional_max_pool3d": PROXY_ITERATED,
4448        "max_pool1d": PROXY_ITERATED,
4449        "max_pool2d": PROXY_ITERATED,
4450        "max_pool3d": PROXY_ITERATED,
4451
4452        "group_norm": CONTROL_FLOW
4453    }
4454
4455    @classmethod
4456    def _get_functional(cls):
4457        functional_list = []
4458        for f in dir(torch.nn.functional):
4459            if not f.islower():
4460                continue
4461            # Ignore internal functions
4462            if f.startswith('_'):
4463                continue
4464            # Ignore supporting functions
4465            if f in cls.IGNORE_FUNCS:
4466                continue
4467            fn = getattr(torch.nn.functional, f)
4468            # Ignore non-callable object like modules
4469            if not isinstance(fn, Callable):
4470                continue
4471            if f not in cls.FUNCTIONALS_WITHOUT_ANNOTATION:
4472                try:
4473                    sig = inspect.signature(fn)
4474                    has_tensor_arg = False
4475                    for param in sig.parameters.values():
4476                        if isinstance(param.annotation, type) and issubclass(param.annotation, torch.Tensor):
4477                            has_tensor_arg = True
4478                    if not has_tensor_arg:
4479                        continue
4480                # No signature or Object is not supported
4481                except ValueError:
4482                    pass
4483            functional_list.append((f, fn))
4484        return functional_list
4485
4486    @classmethod
4487    def generate_test_func(cls, func_name, fn):
4488
4489        def functional_test(self):
4490            if func_name in self.UNTRACEABLE_FUNCTIONALS_PY38 and \
4491                    sys.version_info >= (3, 8) and sys.version_info < (3, 12):
4492                exc, err = self.UNTRACEABLE_FUNCTIONALS_PY38[func_name]
4493                with self.assertRaisesRegex(exc, err):
4494                    symbolic_trace(fn)
4495            elif func_name in self.UNTRACEABLE_FUNCTIONALS:
4496                exc, err = self.UNTRACEABLE_FUNCTIONALS[func_name]
4497                with self.assertRaisesRegex(exc, err):
4498                    symbolic_trace(fn)
4499            else:
4500                symbolic_trace(fn)
4501        return functional_test
4502
4503    @classmethod
4504    def generate_tests(cls):
4505        functional_list = cls._get_functional()
4506        for func_name, fn in functional_list:
4507            test_name = "test_nn_functional_" + func_name
4508            functional_test = cls.generate_test_func(func_name, fn)
4509            setattr(cls, test_name, functional_test)
4510
4511    @classmethod
4512    def setUpClass(cls):
4513
4514        def no(*args, **kwargs):
4515            return False
4516
4517        for name in cls.TO_PATCH.keys():
4518            cls.TO_PATCH[name] = getattr(torch.nn.functional, name)
4519            setattr(torch.nn.functional, name, no)
4520
4521    @classmethod
4522    def tearDownClass(cls):
4523        for name in cls.TO_PATCH.keys():
4524            setattr(torch.nn.functional, name, cls.TO_PATCH[name])
4525
# Attach the generated `test_nn_functional_<name>` methods to the class at
# import time.
TestFunctionalTracing.generate_tests()


# NOTE(review): presumably instantiates the device-specific variants of
# TestOperatorSignatures into this module's globals -- confirm against
# torch.testing._internal.common_device_type.
instantiate_device_type_tests(TestOperatorSignatures, globals())
4530
@skipIfTorchDynamo("too slow")
@skipIfNoTorchVision
class TestVisionTracing(JitTestCase):
    # Generated tests build a torchvision model by name, symbolically trace it,
    # compare the traced output with the eager output, and then script the
    # traced module. Models listed in the tables below are instead expected to
    # fail at the tracing or scripting step.

    def setUp(self):
        # NOTE(review): unlike TestFunctionalTracing.setUp, this does not call
        # super().setUp() -- confirm that is intentional.
        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        tracer_base = torch.fx.proxy.TracerBase
        self.orig_tracer_mutable_flag = tracer_base.check_mutable_operations
        tracer_base.check_mutable_operations = True

    def tearDown(self):
        # Put the flag back to whatever it was before setUp ran.
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag

    # (exception type, regex for the message) pairs referenced by the tables below.
    PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated")
    INCONSISTENT_TYPE = (
        RuntimeError,
        r"Return value was annotated as having type __torch__.torchvision.models[.\w]+ but is actually of type Tensor"
    )

    # Model name -> failure expected from symbolic_trace.
    UNTRACEABLE_MODELS = {
        "fasterrcnn_resnet50_fpn": PROXY_ITERATED,
        "fasterrcnn_resnet50_fpn_v2": PROXY_ITERATED,
        "fasterrcnn_mobilenet_v3_large_320_fpn": PROXY_ITERATED,
        "fasterrcnn_mobilenet_v3_large_fpn": PROXY_ITERATED,
        "maskrcnn_resnet50_fpn": PROXY_ITERATED,
        "maskrcnn_resnet50_fpn_v2": PROXY_ITERATED,
        "keypointrcnn_resnet50_fpn": PROXY_ITERATED,
        "retinanet_resnet50_fpn": PROXY_ITERATED,
        "retinanet_resnet50_fpn_v2": PROXY_ITERATED,
        "ssd300_vgg16": PROXY_ITERATED,
        "fcos_resnet50_fpn": PROXY_ITERATED,
        "ssdlite320_mobilenet_v3_large": PROXY_ITERATED,
    }
    # Model name -> failure expected from torch.jit.script on the traced module.
    UNSCRIPTABLE_MODELS = {
        "googlenet": INCONSISTENT_TYPE,
        "inception_v3": INCONSISTENT_TYPE,
    }

    # Model name -> function applied to a model's output before comparing;
    # models not listed here are compared on their raw output.
    output_transform = {
        "fcn_resnet50": lambda x: x["out"],
        "fcn_resnet101": lambda x: x["out"],
        "deeplabv3_resnet50": lambda x: x["out"],
        "deeplabv3_resnet101": lambda x: x["out"],
        "deeplabv3_mobilenet_v3_large": lambda x: x["out"],
        "lraspp_mobilenet_v3_large": lambda x: x["out"],
        "fasterrcnn_resnet50_fpn": lambda x: x[1],
        "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
        "fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1],
        "maskrcnn_resnet50_fpn": lambda x: x[1],
        "keypointrcnn_resnet50_fpn": lambda x: x[1],
        "retinanet_resnet50_fpn": lambda x: x[1],
    }

    @classmethod
    def generate_test_fn(cls, name, x, kwargs):
        """Build the test method for model ``name``, fed with input ``x`` and built with ``kwargs``."""

        def run_test(self):
            model = torchvision_models.get_model(name, **kwargs).eval()

            if name in self.UNTRACEABLE_MODELS:
                exc_type, message_regex = self.UNTRACEABLE_MODELS[name]
                with self.assertRaisesRegex(exc_type, message_regex):
                    symbolic_trace(model)
                return

            out_transform = self.output_transform.get(name, lambda out: out)
            graph: torch.fx.GraphModule = symbolic_trace(model)
            # Eager and traced outputs must agree.
            a = out_transform(model(x))
            b = out_transform(graph(x))
            self.assertEqual(a, b)

            if name in self.UNSCRIPTABLE_MODELS:
                exc_type, message_regex = self.UNSCRIPTABLE_MODELS[name]
                with self.assertRaisesRegex(exc_type, message_regex):
                    torch.jit.script(graph)
                return

            # The scripted module must agree with eager as well.
            script = torch.jit.script(graph)
            c = out_transform(script(x))
            self.assertEqual(a, c)

        return run_test

    @classmethod
    def _register_model_tests(cls, module, prefix, make_input, kwargs):
        """Attach ``<prefix><model>`` test methods for every model listed for ``module``."""
        for k in torchvision_models.list_models(module=module):
            x = make_input(k)
            # Each generated test gets its own kwargs dict.
            setattr(cls, prefix + k, cls.generate_test_fn(k, x, dict(kwargs)))

    @classmethod
    def generate_classification_tests(cls):
        cls._register_model_tests(
            torchvision_models,
            'test_torchvision_models_',
            lambda k: torch.rand(1, 3, 299, 299) if k in ['inception_v3'] else torch.rand(1, 3, 224, 224),
            {'num_classes': 50},
        )

    @classmethod
    def generate_segmentation_tests(cls):
        cls._register_model_tests(
            torchvision_models.segmentation,
            'test_torchvision_models_segmentation_',
            lambda k: torch.rand(1, 3, 32, 32),
            {'num_classes': 10, 'pretrained_backbone': False},
        )

    @classmethod
    def generate_detection_tests(cls):
        cls._register_model_tests(
            torchvision_models.detection,
            'test_torchvision_models_detection_',
            lambda k: [torch.rand(3, 300, 300)],
            {'num_classes': 10, 'pretrained_backbone': False},
        )

    @classmethod
    def generate_video_tests(cls):
        cls._register_model_tests(
            torchvision_models.video,
            'test_torchvision_models_video_',
            lambda k: (
                torch.rand(1, 3, 4, 112, 112)
                if k not in {"mvit_v1_b", "mvit_v2_s", "s3d"}
                else torch.rand(1, 3, 16, 224, 224)
            ),
            {'num_classes': 50},
        )

    @classmethod
    def generate_tests(cls):
        """Attach the generated tests for all four model families."""
        cls.generate_classification_tests()
        cls.generate_detection_tests()
        cls.generate_segmentation_tests()
        cls.generate_video_tests()
4656
# The torchvision model tests are only attached when HAS_TORCHVISION is truthy
# (presumably set where torchvision is imported -- confirm).
if HAS_TORCHVISION:
    TestVisionTracing.generate_tests()

# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    run_tests()
4662