xref: /aosp_15_r20/external/pytorch/test/test_fx.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: fx"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport builtins
4*da0073e9SAndroid Build Coastguard Workerimport contextlib
5*da0073e9SAndroid Build Coastguard Workerimport copy
6*da0073e9SAndroid Build Coastguard Workerimport functools
7*da0073e9SAndroid Build Coastguard Workerimport inspect
8*da0073e9SAndroid Build Coastguard Workerimport math
9*da0073e9SAndroid Build Coastguard Workerimport numbers
10*da0073e9SAndroid Build Coastguard Workerimport io
11*da0073e9SAndroid Build Coastguard Workerimport operator
12*da0073e9SAndroid Build Coastguard Workerimport os
13*da0073e9SAndroid Build Coastguard Workerimport pickle
14*da0073e9SAndroid Build Coastguard Workerimport sys
15*da0073e9SAndroid Build Coastguard Workerimport torch
16*da0073e9SAndroid Build Coastguard Workerimport traceback
17*da0073e9SAndroid Build Coastguard Workerimport typing
18*da0073e9SAndroid Build Coastguard Workerimport types
19*da0073e9SAndroid Build Coastguard Workerimport warnings
20*da0073e9SAndroid Build Coastguard Workerimport unittest
21*da0073e9SAndroid Build Coastguard Workerfrom math import sqrt
22*da0073e9SAndroid Build Coastguard Workerfrom functorch.experimental import control_flow
23*da0073e9SAndroid Build Coastguard Workerfrom torch.multiprocessing import Process
24*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck
25*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import op_db
26*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests
27*da0073e9SAndroid Build Coastguard Workerimport torch.utils._pytree as pytree
28*da0073e9SAndroid Build Coastguard Workerimport torch.fx._pytree as fx_pytree
29*da0073e9SAndroid Build Coastguard Workerfrom torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap, PH, CodeGen
30*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.node import Target, Argument, _format_arg
31*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.passes import shape_prop
32*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.immutable_collections import immutable_dict, immutable_list
33*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental.rewriter import RewritingTracer
34*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.operator_schemas import get_signature_for_torch_op
35*da0073e9SAndroid Build Coastguard Workerfrom copy import deepcopy
36*da0073e9SAndroid Build Coastguard Workerfrom collections import namedtuple
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.proxy import TraceError
39*da0073e9SAndroid Build Coastguard Workerfrom torch.fx._compatibility import _BACK_COMPAT_OBJECTS, _MARKED_WITH_COMPATIBILITY
40*da0073e9SAndroid Build Coastguard Workerfrom torch.fx._symbolic_trace import PHBase, PHWithMeta
41*da0073e9SAndroid Build Coastguard Workerfrom fx.test_subgraph_rewriter import TestSubgraphRewriter  # noqa: F401
42*da0073e9SAndroid Build Coastguard Workerfrom fx.test_dce_pass import TestDCE  # noqa: F401
43*da0073e9SAndroid Build Coastguard Workerfrom fx.test_fx_const_fold import TestConstFold  # noqa: F401
44*da0073e9SAndroid Build Coastguard Workerfrom fx.test_fx_param_shape_control_flow import TestConstParamShapeInControlFlow  # noqa: F401
45*da0073e9SAndroid Build Coastguard Workerfrom fx.test_pass_infra import TestPassManager  # noqa: F401
46*da0073e9SAndroid Build Coastguard Workerfrom fx.test_common_passes import TestCommonPass  # noqa: F401
47*da0073e9SAndroid Build Coastguard Workerfrom fx.test_cse_pass import TestCSEPass  # noqa: F401
48*da0073e9SAndroid Build Coastguard Workerfrom fx.test_matcher_utils import TestMatcher  # noqa: F401
49*da0073e9SAndroid Build Coastguard Workerfrom fx.test_source_matcher_utils import TestSourceMatcher  # noqa: F401
50*da0073e9SAndroid Build Coastguard Worker
51*da0073e9SAndroid Build Coastguard Workerfrom fx.test_gradual_type import AnnotationsTest  # noqa: F401
52*da0073e9SAndroid Build Coastguard Workerfrom fx.test_gradual_type import TypeCheckerTest  # noqa: F401
53*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Callable, Dict, NamedTuple, List, Optional, Set, Tuple, Union
54*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
55*da0073e9SAndroid Build Coastguard Worker    IS_FBCODE,
56*da0073e9SAndroid Build Coastguard Worker    IS_MACOS,
57*da0073e9SAndroid Build Coastguard Worker    IS_WINDOWS,
58*da0073e9SAndroid Build Coastguard Worker    find_library_location,
59*da0073e9SAndroid Build Coastguard Worker    run_tests,
60*da0073e9SAndroid Build Coastguard Worker    skipIfTorchDynamo,
61*da0073e9SAndroid Build Coastguard Worker)
62*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase
63*da0073e9SAndroid Build Coastguard Worker
64*da0073e9SAndroid Build Coastguard Workerfrom fx.named_tup import MyNamedTup
65*da0073e9SAndroid Build Coastguard Worker
# torchvision is optional. HAS_TORCHVISION records whether it imported, and
# skipIfNoTorchVision is a unittest skip decorator for tests that need it.
try:
    from torchvision import models as torchvision_models
except ImportError:
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
72*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_quantization import skipIfNoDynamoSupport
73*da0073e9SAndroid Build Coastguard Worker
class SimpleTest(torch.nn.Module):
    """Tiny module computing ``relu(x + 3.0)``."""

    def forward(self, x):
        shifted = x + 3.0
        return torch.relu(shifted)
77*da0073e9SAndroid Build Coastguard Worker
def a_non_torch_leaf(a, b):
    """Plain Python helper returning ``a + b`` (used as a ``call_function`` target)."""
    return operator.add(a, b)
80*da0073e9SAndroid Build Coastguard Worker
# Used for test_autowrap_function. Autowrapped functions need to be global.
def fx_int(x: float) -> int:
    """Return ``int(x)``."""
    converted = int(x)
    return converted
84*da0073e9SAndroid Build Coastguard Worker
def fx_int_x2(x: float) -> int:
    """Return ``int(x)`` doubled."""
    return 2 * int(x)
87*da0073e9SAndroid Build Coastguard Worker
# Used in test_pytree. Defined at module level because pickling a GraphModule
# that uses Point errors out if Point is local to the test function.
Point = namedtuple('Point', ('x', 'y'))
91*da0073e9SAndroid Build Coastguard Worker
# Test wrap() passing both a function name as well as a function
# directly
def a_lifted_leaf(a, b):
    """Return ``a[0] + a[1] + b``; registered with ``wrap`` by name below."""
    first, second = a[0], a[1]
    return first + second + b


wrap('a_lifted_leaf')
# Registering the same name a second time must be harmless.
wrap('a_lifted_leaf')
100*da0073e9SAndroid Build Coastguard Worker
def a_lifted_leaf2(a, b):
    """Return ``a[0] + a[1] + b``; registered by passing the function object to ``wrap``."""
    head, tail = a[0], a[1]
    return head + tail + b


wrap(a_lifted_leaf2)
105*da0073e9SAndroid Build Coastguard Worker
# Register the builtins ``len`` and ``getattr`` with ``wrap`` (by name, from
# module scope) so FX treats calls to them from this module as leaf calls.
for _builtin_name in ('len', 'getattr'):
    wrap(_builtin_name)
del _builtin_name
109*da0073e9SAndroid Build Coastguard Worker
def wrapped_named_tup(p1, *, p2):
    """Return ``p1.x + p2.y``; ``p2`` is keyword-only."""
    lhs = p1.x
    rhs = p2.y
    return lhs + rhs


wrap(wrapped_named_tup)
114*da0073e9SAndroid Build Coastguard Worker
@wrap
def wrapped_via_decorator(a):
    """Return ``a + 1``; registered by using ``wrap`` as a decorator."""
    incremented = a + 1
    return incremented
118*da0073e9SAndroid Build Coastguard Worker
wrap('wrapped_with_submodule')


def wrapped_with_submodule(x: torch.Tensor, batchnorm1d: torch.nn.BatchNorm1d):
    """Apply the given ``batchnorm1d`` module to ``x`` and return the result."""
    normalized = batchnorm1d(x)
    return normalized
123*da0073e9SAndroid Build Coastguard Worker
def my_decorator(f):
    """Return a pass-through wrapper around ``f``.

    ``functools.update_wrapper`` copies ``f``'s metadata (``__name__``,
    ``__doc__``, ``__wrapped__``, ...) onto the wrapper.
    """
    def wrapper_inside_decorator(*args, **kwargs):
        return f(*args, **kwargs)

    return functools.update_wrapper(wrapper_inside_decorator, f)
129*da0073e9SAndroid Build Coastguard Worker
@wrap
@my_decorator
def wrapped_decorated_fn(x):
    """Identity function, decorated with ``my_decorator`` before being registered with ``wrap``."""
    identity = x
    return identity
134*da0073e9SAndroid Build Coastguard Worker
# Keep references to the original function objects so tests can check
# (e.g. with assertIs) that the module-level names still point at them after
# tracing.
real_wrapped_via_decorator, real_a_lifed_leaf, real_a_lifed_leaf2 = (
    wrapped_via_decorator,
    a_lifted_leaf,
    a_lifted_leaf2,
)
# NOTE(review): alias of math.sqrt; its consumer is outside this view.
_sqrt = sqrt
139*da0073e9SAndroid Build Coastguard Worker
wrap('wrapper_fn')


def wrapper_fn(x):
    """Call ``torch.foo(x)``.

    NOTE(review): ``torch.foo`` does not appear to be a real torch attribute,
    so calling this eagerly raises AttributeError. Presumably it is only meant
    to be reached through tracing, where the ``wrap`` registration records the
    call instead of running it. Confirm against the tests that use it.
    """
    foo = torch.foo
    return foo(x)
144*da0073e9SAndroid Build Coastguard Worker
class Pair(NamedTuple):
    """Named pair of tensors that supplies its own ``_custom_fx_repr_fn``."""

    x : torch.Tensor
    y : torch.Tensor

    def _custom_fx_repr_fn(self) -> str:
        """Render as ``Pair(x=..., y=...)``, formatting each field with ``_format_arg``."""
        x_repr = _format_arg(self.x)
        y_repr = _format_arg(self.y)
        return f"Pair(x={x_repr}, y={y_repr})"
151*da0073e9SAndroid Build Coastguard Worker
# for testing pytrees
class Foo:  # noqa: B209
    """Plain container holding two attributes, ``a`` and ``b``."""

    def __init__(self, a, b):
        self.a, self.b = a, b
157*da0073e9SAndroid Build Coastguard Worker
class Add(torch.nn.Module):
    """Module returning its input added to itself (``x + x``)."""

    def forward(self, x):
        doubled = x + x
        return doubled
161*da0073e9SAndroid Build Coastguard Worker
162*da0073e9SAndroid Build Coastguard Worker@torch.fx.has_side_effect
163*da0073e9SAndroid Build Coastguard Worker@torch.fx.wrap
164*da0073e9SAndroid Build Coastguard Workerdef side_effect_func(x: torch.Tensor):
165*da0073e9SAndroid Build Coastguard Worker    print(x)
166*da0073e9SAndroid Build Coastguard Worker
167*da0073e9SAndroid Build Coastguard Workerclass TestFX(JitTestCase):
168*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
169*da0073e9SAndroid Build Coastguard Worker        super().setUp()
170*da0073e9SAndroid Build Coastguard Worker        # Checking for mutable operations whil tracing is feature flagged
171*da0073e9SAndroid Build Coastguard Worker        # Enable it in testing but not by default
172*da0073e9SAndroid Build Coastguard Worker        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
173*da0073e9SAndroid Build Coastguard Worker        torch.fx.proxy.TracerBase.check_mutable_operations = True
174*da0073e9SAndroid Build Coastguard Worker
175*da0073e9SAndroid Build Coastguard Worker        if not (IS_FBCODE or IS_WINDOWS or IS_MACOS):
176*da0073e9SAndroid Build Coastguard Worker            lib_file_path = find_library_location('libtorchbind_test.so')
177*da0073e9SAndroid Build Coastguard Worker            torch.ops.load_library(str(lib_file_path))
178*da0073e9SAndroid Build Coastguard Worker
179*da0073e9SAndroid Build Coastguard Worker    def tearDown(self):
180*da0073e9SAndroid Build Coastguard Worker        super().tearDown()
181*da0073e9SAndroid Build Coastguard Worker        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag
182*da0073e9SAndroid Build Coastguard Worker
183*da0073e9SAndroid Build Coastguard Worker    def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None):
184*da0073e9SAndroid Build Coastguard Worker        """Check that an nn.Module's results match the GraphModule version
185*da0073e9SAndroid Build Coastguard Worker        for a given set of args/kwargs.
186*da0073e9SAndroid Build Coastguard Worker        """
187*da0073e9SAndroid Build Coastguard Worker        kwargs = kwargs if kwargs else {}
188*da0073e9SAndroid Build Coastguard Worker        ref_outs = m(*args, **kwargs)
189*da0073e9SAndroid Build Coastguard Worker        gm = symbolic_trace(m)
190*da0073e9SAndroid Build Coastguard Worker        gm.graph.lint()
191*da0073e9SAndroid Build Coastguard Worker        test_outs = gm(*args, **kwargs)
192*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref_outs, test_outs)
193*da0073e9SAndroid Build Coastguard Worker
194*da0073e9SAndroid Build Coastguard Worker    def test_graph_module(self):
195*da0073e9SAndroid Build Coastguard Worker        class MySub(torch.nn.Module):
196*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
197*da0073e9SAndroid Build Coastguard Worker                super().__init__()
198*da0073e9SAndroid Build Coastguard Worker                self.w = torch.nn.Parameter(torch.rand(4, 3))
199*da0073e9SAndroid Build Coastguard Worker
200*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
201*da0073e9SAndroid Build Coastguard Worker                return self.w + x
202*da0073e9SAndroid Build Coastguard Worker
203*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
204*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
205*da0073e9SAndroid Build Coastguard Worker                super().__init__()
206*da0073e9SAndroid Build Coastguard Worker                self.lin = torch.nn.Linear(4, 3)
207*da0073e9SAndroid Build Coastguard Worker                self.sub_mod = MySub()
208*da0073e9SAndroid Build Coastguard Worker                self.w = torch.nn.Parameter(torch.rand(3))
209*da0073e9SAndroid Build Coastguard Worker
210*da0073e9SAndroid Build Coastguard Worker            def forward(self, A, B, c):
211*da0073e9SAndroid Build Coastguard Worker                t = torch.sigmoid(A) + self.lin(c)
212*da0073e9SAndroid Build Coastguard Worker                return self.sub_mod(t.data + self.w + t + 1 - A + B // A + -A + A.add(B, alpha=3))
213*da0073e9SAndroid Build Coastguard Worker
214*da0073e9SAndroid Build Coastguard Worker        m = MyModule()
215*da0073e9SAndroid Build Coastguard Worker        gm = symbolic_trace(m)
216*da0073e9SAndroid Build Coastguard Worker
217*da0073e9SAndroid Build Coastguard Worker        ms = torch.jit.script(gm)
218*da0073e9SAndroid Build Coastguard Worker
219*da0073e9SAndroid Build Coastguard Worker        class M2(torch.nn.Module):
220*da0073e9SAndroid Build Coastguard Worker            def forward(self, A):
221*da0073e9SAndroid Build Coastguard Worker                m, idx = torch.max(A, 0)
222*da0073e9SAndroid Build Coastguard Worker                return m + 1, idx + 1
223*da0073e9SAndroid Build Coastguard Worker
224*da0073e9SAndroid Build Coastguard Worker        m2 = M2()
225*da0073e9SAndroid Build Coastguard Worker        gm2 = symbolic_trace(m2)
226*da0073e9SAndroid Build Coastguard Worker
227*da0073e9SAndroid Build Coastguard Worker        class T(torch.nn.Module):
228*da0073e9SAndroid Build Coastguard Worker
229*da0073e9SAndroid Build Coastguard Worker            def forward(self, A, b=4, *args, c=5, **kwargs):
230*da0073e9SAndroid Build Coastguard Worker                x = A + 1 + args[0] + kwargs['3']
231*da0073e9SAndroid Build Coastguard Worker                return x
232*da0073e9SAndroid Build Coastguard Worker
233*da0073e9SAndroid Build Coastguard Worker        t = T()
234*da0073e9SAndroid Build Coastguard Worker        symbolic_trace(t)
235*da0073e9SAndroid Build Coastguard Worker
236*da0073e9SAndroid Build Coastguard Worker        # test for issue described at https://github.com/pytorch/pytorch/issues/63883
237*da0073e9SAndroid Build Coastguard Worker        class M3(torch.nn.Module):
238*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
239*da0073e9SAndroid Build Coastguard Worker                return torch.relu(x)
240*da0073e9SAndroid Build Coastguard Worker
241*da0073e9SAndroid Build Coastguard Worker        m3 = M3()
242*da0073e9SAndroid Build Coastguard Worker        gm3 = symbolic_trace(m3)
243*da0073e9SAndroid Build Coastguard Worker        new_instance = gm3.__new__(type(gm3))
244*da0073e9SAndroid Build Coastguard Worker        new_instance.__init__(gm3, gm3.graph)
245*da0073e9SAndroid Build Coastguard Worker
246*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 3)
247*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(new_instance(x), torch.relu(x))
248*da0073e9SAndroid Build Coastguard Worker
249*da0073e9SAndroid Build Coastguard Worker    def test_informative_co_filename(self):
250*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
251*da0073e9SAndroid Build Coastguard Worker            def forward(self, a):
252*da0073e9SAndroid Build Coastguard Worker                return a * 2
253*da0073e9SAndroid Build Coastguard Worker
254*da0073e9SAndroid Build Coastguard Worker        gm = symbolic_trace(MyModule())
255*da0073e9SAndroid Build Coastguard Worker        self.assertIn(os.path.basename(__file__), gm.forward.__code__.co_filename)
256*da0073e9SAndroid Build Coastguard Worker
257*da0073e9SAndroid Build Coastguard Worker    def test_custom_import(self):
258*da0073e9SAndroid Build Coastguard Worker        graph = torch.fx.Graph()
259*da0073e9SAndroid Build Coastguard Worker        a = graph.placeholder('x')
260*da0073e9SAndroid Build Coastguard Worker        b = graph.placeholder('y')
261*da0073e9SAndroid Build Coastguard Worker        c = graph.call_function(a_non_torch_leaf, (a, b))
262*da0073e9SAndroid Build Coastguard Worker        d = graph.call_function(torch.sin, (c,))
263*da0073e9SAndroid Build Coastguard Worker        graph.output(d)
264*da0073e9SAndroid Build Coastguard Worker        gm = GraphModule(torch.nn.Module(), graph)
265*da0073e9SAndroid Build Coastguard Worker        x, y = torch.rand(1), torch.rand(1)
266*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.sin(x + y), gm(x, y))
267*da0073e9SAndroid Build Coastguard Worker
268*da0073e9SAndroid Build Coastguard Worker    def test_args_kwargs(self):
269*da0073e9SAndroid Build Coastguard Worker        class T(torch.nn.Module):
270*da0073e9SAndroid Build Coastguard Worker            def forward(self, *args, **kwargs):
271*da0073e9SAndroid Build Coastguard Worker                x = args[0] + kwargs['foo']
272*da0073e9SAndroid Build Coastguard Worker                return x
273*da0073e9SAndroid Build Coastguard Worker
274*da0073e9SAndroid Build Coastguard Worker        t = T()
275*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})
276*da0073e9SAndroid Build Coastguard Worker
277*da0073e9SAndroid Build Coastguard Worker    def test_varargs_concrete(self):
278*da0073e9SAndroid Build Coastguard Worker        class T(torch.nn.Module):
279*da0073e9SAndroid Build Coastguard Worker            def forward(self, *args, **kwargs):
280*da0073e9SAndroid Build Coastguard Worker                x = args[0] + args[1]
281*da0073e9SAndroid Build Coastguard Worker                return x
282*da0073e9SAndroid Build Coastguard Worker
283*da0073e9SAndroid Build Coastguard Worker        args = (torch.rand(1), torch.rand(1))
284*da0073e9SAndroid Build Coastguard Worker
285*da0073e9SAndroid Build Coastguard Worker        t = T()
286*da0073e9SAndroid Build Coastguard Worker        ref_outs = t(*args)
287*da0073e9SAndroid Build Coastguard Worker        gm = symbolic_trace(t, concrete_args=(torch.fx.PH, torch.fx.PH))
288*da0073e9SAndroid Build Coastguard Worker        gm.graph.lint()
289*da0073e9SAndroid Build Coastguard Worker        test_outs = gm(*args)
290*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref_outs, test_outs)
291*da0073e9SAndroid Build Coastguard Worker
292*da0073e9SAndroid Build Coastguard Worker    def test_args_kwargs_no_self(self):
293*da0073e9SAndroid Build Coastguard Worker        class T(torch.nn.Module):
294*da0073e9SAndroid Build Coastguard Worker            def forward(*args, **kwargs):  # noqa: B902
295*da0073e9SAndroid Build Coastguard Worker                self = args[0]
296*da0073e9SAndroid Build Coastguard Worker                return torch.relu(args[1])
297*da0073e9SAndroid Build Coastguard Worker
298*da0073e9SAndroid Build Coastguard Worker        t = T()
299*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'cannot be part of \*args expansion'):
300*da0073e9SAndroid Build Coastguard Worker            self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})
301*da0073e9SAndroid Build Coastguard Worker
302*da0073e9SAndroid Build Coastguard Worker    def test_fx_shifts(self):
303*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
304*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
305*da0073e9SAndroid Build Coastguard Worker                return x << 3, x >> 3
306*da0073e9SAndroid Build Coastguard Worker
307*da0073e9SAndroid Build Coastguard Worker        input = torch.LongTensor(10).random_(0, 1024)
308*da0073e9SAndroid Build Coastguard Worker
309*da0073e9SAndroid Build Coastguard Worker        m = MyModule()
310*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(m, (input,))
311*da0073e9SAndroid Build Coastguard Worker
312*da0073e9SAndroid Build Coastguard Worker    def test_fx_and_or(self):
313*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
314*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
315*da0073e9SAndroid Build Coastguard Worker                return x & x, x | x
316*da0073e9SAndroid Build Coastguard Worker
317*da0073e9SAndroid Build Coastguard Worker        input = torch.LongTensor(10).random_(0, 1024)
318*da0073e9SAndroid Build Coastguard Worker
319*da0073e9SAndroid Build Coastguard Worker        m = MyModule()
320*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(m, (input,))
321*da0073e9SAndroid Build Coastguard Worker
322*da0073e9SAndroid Build Coastguard Worker    def test_dict(self):
323*da0073e9SAndroid Build Coastguard Worker        class MyDictMod(torch.nn.Module):
324*da0073e9SAndroid Build Coastguard Worker            def forward(self, d):
325*da0073e9SAndroid Build Coastguard Worker                return d['3'].relu(), {'4' : d['3'].neg()}
326*da0073e9SAndroid Build Coastguard Worker
327*da0073e9SAndroid Build Coastguard Worker        input_dict = {'3': torch.rand(3, 4)}
328*da0073e9SAndroid Build Coastguard Worker        m = MyDictMod()
329*da0073e9SAndroid Build Coastguard Worker
330*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(m, (input_dict,))
331*da0073e9SAndroid Build Coastguard Worker
332*da0073e9SAndroid Build Coastguard Worker    def test_matmul_tracing(self):
333*da0073e9SAndroid Build Coastguard Worker        const = torch.randn(3)
334*da0073e9SAndroid Build Coastguard Worker
335*da0073e9SAndroid Build Coastguard Worker        def matmul_f(x):
336*da0073e9SAndroid Build Coastguard Worker            return x @ const
337*da0073e9SAndroid Build Coastguard Worker
338*da0073e9SAndroid Build Coastguard Worker        mod = symbolic_trace(matmul_f)
339*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(3)
340*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(mod(inp), matmul_f(inp))
341*da0073e9SAndroid Build Coastguard Worker
342*da0073e9SAndroid Build Coastguard Worker        def rmatmul_f(x):
343*da0073e9SAndroid Build Coastguard Worker            return const @ x
344*da0073e9SAndroid Build Coastguard Worker
345*da0073e9SAndroid Build Coastguard Worker        mod = symbolic_trace(rmatmul_f)
346*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(3)
347*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(mod(inp), rmatmul_f(inp))
348*da0073e9SAndroid Build Coastguard Worker
349*da0073e9SAndroid Build Coastguard Worker    @skipIfNoDynamoSupport
350*da0073e9SAndroid Build Coastguard Worker    def test_control_flow_tracing(self):
351*da0073e9SAndroid Build Coastguard Worker        def true(x, y):
352*da0073e9SAndroid Build Coastguard Worker            return x + y
353*da0073e9SAndroid Build Coastguard Worker
354*da0073e9SAndroid Build Coastguard Worker        def false(x, y):
355*da0073e9SAndroid Build Coastguard Worker            return x - y
356*da0073e9SAndroid Build Coastguard Worker
357*da0073e9SAndroid Build Coastguard Worker        def f(x, y):
358*da0073e9SAndroid Build Coastguard Worker            x = control_flow.cond(x[0] == 0, true, false, [x, y])
359*da0073e9SAndroid Build Coastguard Worker
360*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"Expected pred to be bool or tensor, but got Proxy\(eq\)"):
361*da0073e9SAndroid Build Coastguard Worker            _ = symbolic_trace(f)
362*da0073e9SAndroid Build Coastguard Worker
363*da0073e9SAndroid Build Coastguard Worker    def test_disallow_override(self):
364*da0073e9SAndroid Build Coastguard Worker        # Custom delegate to disallow in-place tensor operations
365*da0073e9SAndroid Build Coastguard Worker        class NoMutableCallTracer(Tracer):
366*da0073e9SAndroid Build Coastguard Worker            def create_node(self, kind : str, target : Union[str, Callable],
367*da0073e9SAndroid Build Coastguard Worker                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
368*da0073e9SAndroid Build Coastguard Worker                            type_expr : Optional[Any] = None) -> Node:
369*da0073e9SAndroid Build Coastguard Worker                name = target if isinstance(target, str) else torch.typename(target)
370*da0073e9SAndroid Build Coastguard Worker                if name[-1] == '_':
371*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError('In-place operations are not supported')
372*da0073e9SAndroid Build Coastguard Worker                return super().create_node(kind, target, args, kwargs, name)
373*da0073e9SAndroid Build Coastguard Worker
374*da0073e9SAndroid Build Coastguard Worker        # Test method
375*da0073e9SAndroid Build Coastguard Worker        class MyInplaceMod(torch.nn.Module):
376*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
377*da0073e9SAndroid Build Coastguard Worker                x.add_(3.0)
378*da0073e9SAndroid Build Coastguard Worker                return x
379*da0073e9SAndroid Build Coastguard Worker
380*da0073e9SAndroid Build Coastguard Worker        m = MyInplaceMod()
381*da0073e9SAndroid Build Coastguard Worker
382*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
383*da0073e9SAndroid Build Coastguard Worker            NoMutableCallTracer().trace(m)
384*da0073e9SAndroid Build Coastguard Worker
385*da0073e9SAndroid Build Coastguard Worker        # Test free function
386*da0073e9SAndroid Build Coastguard Worker        class MyInplaceMod2(torch.nn.Module):
387*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
388*da0073e9SAndroid Build Coastguard Worker                torch.log_(x)
389*da0073e9SAndroid Build Coastguard Worker                return x
390*da0073e9SAndroid Build Coastguard Worker        m2 = MyInplaceMod2()
391*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
392*da0073e9SAndroid Build Coastguard Worker            NoMutableCallTracer().trace(m2)
393*da0073e9SAndroid Build Coastguard Worker
394*da0073e9SAndroid Build Coastguard Worker        # Test symbolic node as an arg
395*da0073e9SAndroid Build Coastguard Worker        class MyInplaceMod3(torch.nn.Module):
396*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
397*da0073e9SAndroid Build Coastguard Worker                y = torch.ones(3, 4)
398*da0073e9SAndroid Build Coastguard Worker                y.add_(x)
399*da0073e9SAndroid Build Coastguard Worker                return x
400*da0073e9SAndroid Build Coastguard Worker        m3 = MyInplaceMod3()
401*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
402*da0073e9SAndroid Build Coastguard Worker            NoMutableCallTracer().trace(m3)
403*da0073e9SAndroid Build Coastguard Worker
404*da0073e9SAndroid Build Coastguard Worker    def test_leaf_module(self):
405*da0073e9SAndroid Build Coastguard Worker        # Custom delegate to make it so that there are no leaf modules, everything
406*da0073e9SAndroid Build Coastguard Worker        # should get traced through
407*da0073e9SAndroid Build Coastguard Worker        class NoLeafModulesTracer(Tracer):
408*da0073e9SAndroid Build Coastguard Worker            def is_leaf_module(self, m, qualname):
409*da0073e9SAndroid Build Coastguard Worker                return False
410*da0073e9SAndroid Build Coastguard Worker
411*da0073e9SAndroid Build Coastguard Worker        class MyReluMod(torch.nn.Module):
412*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
413*da0073e9SAndroid Build Coastguard Worker                super().__init__()
414*da0073e9SAndroid Build Coastguard Worker                self.relu = torch.nn.ReLU()
415*da0073e9SAndroid Build Coastguard Worker
416*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
417*da0073e9SAndroid Build Coastguard Worker                return self.relu(x)
418*da0073e9SAndroid Build Coastguard Worker
419*da0073e9SAndroid Build Coastguard Worker        mrm = MyReluMod()
420*da0073e9SAndroid Build Coastguard Worker        sym = NoLeafModulesTracer().trace(mrm)
421*da0073e9SAndroid Build Coastguard Worker        for node in sym.nodes:
422*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(node.op, 'call_module')
423*da0073e9SAndroid Build Coastguard Worker        sym.lint()
424*da0073e9SAndroid Build Coastguard Worker
425*da0073e9SAndroid Build Coastguard Worker    def test_wrap(self):
426*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5))
427*da0073e9SAndroid Build Coastguard Worker
428*da0073e9SAndroid Build Coastguard Worker        def to_trace(y):
429*da0073e9SAndroid Build Coastguard Worker            return a_lifted_leaf((4, y), 3) + a_lifted_leaf((3, 4), 5) + a_lifted_leaf((y, y), y)
430*da0073e9SAndroid Build Coastguard Worker
431*da0073e9SAndroid Build Coastguard Worker        m = symbolic_trace(to_trace)
432*da0073e9SAndroid Build Coastguard Worker        self.assertIn('a_lifted_leaf', m.code)
433*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(27, m(2))
434*da0073e9SAndroid Build Coastguard Worker        self.assertIs(a_lifted_leaf, real_a_lifed_leaf)
435*da0073e9SAndroid Build Coastguard Worker
436*da0073e9SAndroid Build Coastguard Worker    def test_wrap_fn_directly(self):
437*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))
438*da0073e9SAndroid Build Coastguard Worker
439*da0073e9SAndroid Build Coastguard Worker        def to_trace(y):
440*da0073e9SAndroid Build Coastguard Worker            return a_lifted_leaf2((4, y), 3) + a_lifted_leaf2((3, 4), 5) + a_lifted_leaf2((y, y), y)
441*da0073e9SAndroid Build Coastguard Worker
442*da0073e9SAndroid Build Coastguard Worker        m = symbolic_trace(to_trace)
443*da0073e9SAndroid Build Coastguard Worker        self.assertIn('a_lifted_leaf2', m.code)
444*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(27, m(2))
445*da0073e9SAndroid Build Coastguard Worker        self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2)
446*da0073e9SAndroid Build Coastguard Worker
447*da0073e9SAndroid Build Coastguard Worker    def test_wrapped_via_decorator(self):
448*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(wrapped_via_decorator(0), 1)
449*da0073e9SAndroid Build Coastguard Worker
450*da0073e9SAndroid Build Coastguard Worker        def to_trace(y):
451*da0073e9SAndroid Build Coastguard Worker            return wrapped_via_decorator(y)
452*da0073e9SAndroid Build Coastguard Worker
453*da0073e9SAndroid Build Coastguard Worker        m = symbolic_trace(to_trace)
454*da0073e9SAndroid Build Coastguard Worker        self.assertIn('wrapped_via_decorator', m.code)
455*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m(0), 1)
456*da0073e9SAndroid Build Coastguard Worker        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
457*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
458*da0073e9SAndroid Build Coastguard Worker
459*da0073e9SAndroid Build Coastguard Worker    def test_wrapped_via_decorator_and_transformed(self):
460*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(wrapped_via_decorator(0), 1)
461*da0073e9SAndroid Build Coastguard Worker
462*da0073e9SAndroid Build Coastguard Worker        def to_trace(y):
463*da0073e9SAndroid Build Coastguard Worker            return wrapped_via_decorator(y)
464*da0073e9SAndroid Build Coastguard Worker
465*da0073e9SAndroid Build Coastguard Worker        m = symbolic_trace(to_trace)
466*da0073e9SAndroid Build Coastguard Worker        self.assertIn('wrapped_via_decorator', m.code)
467*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m(0), 1)
468*da0073e9SAndroid Build Coastguard Worker        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
469*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
470*da0073e9SAndroid Build Coastguard Worker
471*da0073e9SAndroid Build Coastguard Worker        transformed = torch.fx.Transformer(m).transform()
472*da0073e9SAndroid Build Coastguard Worker        self.assertIn('wrapped_via_decorator', transformed.code)
473*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(transformed(0), 1)
474*da0073e9SAndroid Build Coastguard Worker        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
475*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
476*da0073e9SAndroid Build Coastguard Worker
477*da0073e9SAndroid Build Coastguard Worker    def test_wrap_with_submodule(self):
478*da0073e9SAndroid Build Coastguard Worker
479*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
480*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
481*da0073e9SAndroid Build Coastguard Worker                super().__init__()
482*da0073e9SAndroid Build Coastguard Worker                self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
483*da0073e9SAndroid Build Coastguard Worker
484*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor):
485*da0073e9SAndroid Build Coastguard Worker                return wrapped_with_submodule(x, self.batchnorm1d)
486*da0073e9SAndroid Build Coastguard Worker
487*da0073e9SAndroid Build Coastguard Worker        m = symbolic_trace(M())
488*da0073e9SAndroid Build Coastguard Worker
489*da0073e9SAndroid Build Coastguard Worker        self.assertIn("wrapped_with_submodule", m.code)
490*da0073e9SAndroid Build Coastguard Worker
491*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(3, 2)
492*da0073e9SAndroid Build Coastguard Worker        ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
493*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref_batchnorm1d(input), m(input))
494*da0073e9SAndroid Build Coastguard Worker
495*da0073e9SAndroid Build Coastguard Worker    def test_wrapped_retrace(self):
496*da0073e9SAndroid Build Coastguard Worker        def to_trace(y):
497*da0073e9SAndroid Build Coastguard Worker            return wrapped_via_decorator(y)
498*da0073e9SAndroid Build Coastguard Worker
499*da0073e9SAndroid Build Coastguard Worker        m = symbolic_trace(to_trace)
500*da0073e9SAndroid Build Coastguard Worker        self.assertIn('wrapped_via_decorator', m.code)
501*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m(0), 1)
502*da0073e9SAndroid Build Coastguard Worker
503*da0073e9SAndroid Build Coastguard Worker        retraced = symbolic_trace(m)
504*da0073e9SAndroid Build Coastguard Worker        self.assertIn('wrapped_via_decorator', retraced.code)
505*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(retraced(0), 1)
506*da0073e9SAndroid Build Coastguard Worker
507*da0073e9SAndroid Build Coastguard Worker    def test_wrap_decorated_function(self):
508*da0073e9SAndroid Build Coastguard Worker        def to_trace(y):
509*da0073e9SAndroid Build Coastguard Worker            return wrapped_decorated_fn(y)
510*da0073e9SAndroid Build Coastguard Worker
511*da0073e9SAndroid Build Coastguard Worker        m = symbolic_trace(to_trace)
512*da0073e9SAndroid Build Coastguard Worker        self.assertIn('wrapped_decorated_fn', m.code)
513*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m(1), 1)
514*da0073e9SAndroid Build Coastguard Worker
515*da0073e9SAndroid Build Coastguard Worker    def test_graph_edit_with_proxy(self):
516*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
517*da0073e9SAndroid Build Coastguard Worker            def forward(self, a, b):
518*da0073e9SAndroid Build Coastguard Worker                return a + b
519*da0073e9SAndroid Build Coastguard Worker        m = M()
520*da0073e9SAndroid Build Coastguard Worker        g = symbolic_trace(m).graph
521*da0073e9SAndroid Build Coastguard Worker        new_g = torch.fx.Graph()
522*da0073e9SAndroid Build Coastguard Worker        val_map : Dict[Node, Node] = {}
523*da0073e9SAndroid Build Coastguard Worker        output_val = new_g.graph_copy(g, val_map)
524*da0073e9SAndroid Build Coastguard Worker        t = Proxy(output_val)
525*da0073e9SAndroid Build Coastguard Worker        # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
526*da0073e9SAndroid Build Coastguard Worker        new_g.output((t + t).node)
527*da0073e9SAndroid Build Coastguard Worker        gm = GraphModule(m, new_g)
528*da0073e9SAndroid Build Coastguard Worker        gm.graph.lint()
529*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gm(3, 4), 14)
530*da0073e9SAndroid Build Coastguard Worker
531*da0073e9SAndroid Build Coastguard Worker    def test_proxy_deepcopy_without_tracer(self):
532*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
533*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
534*da0073e9SAndroid Build Coastguard Worker                super().__init__()
535*da0073e9SAndroid Build Coastguard Worker
536*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
537*da0073e9SAndroid Build Coastguard Worker                return 2 * x
538*da0073e9SAndroid Build Coastguard Worker
539*da0073e9SAndroid Build Coastguard Worker        module = MyModule()
540*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(module)
541*da0073e9SAndroid Build Coastguard Worker        node = list(traced.graph.nodes)[-2]
542*da0073e9SAndroid Build Coastguard Worker        p = torch.fx.Proxy(node, None)
543*da0073e9SAndroid Build Coastguard Worker        node.proxy = p
544*da0073e9SAndroid Build Coastguard Worker        p2 = copy.deepcopy(p)
545*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(p2, torch.fx.Proxy))
546*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(p2.node.name, node.name)
547*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(p2.node.target, node.target)
548*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(id(p2.node), id(node))
549*da0073e9SAndroid Build Coastguard Worker
550*da0073e9SAndroid Build Coastguard Worker    def test_proxy_deepcopy_with_tracer(self):
551*da0073e9SAndroid Build Coastguard Worker        class TestTracer(Tracer):
552*da0073e9SAndroid Build Coastguard Worker            def __init__(self, name):
553*da0073e9SAndroid Build Coastguard Worker                super().__init__()
554*da0073e9SAndroid Build Coastguard Worker                self.name = name
555*da0073e9SAndroid Build Coastguard Worker
556*da0073e9SAndroid Build Coastguard Worker            def is_leaf_module(self, module, name):
557*da0073e9SAndroid Build Coastguard Worker                return True
558*da0073e9SAndroid Build Coastguard Worker
559*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
560*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
561*da0073e9SAndroid Build Coastguard Worker                super().__init__()
562*da0073e9SAndroid Build Coastguard Worker
563*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
564*da0073e9SAndroid Build Coastguard Worker                return 2 * x
565*da0073e9SAndroid Build Coastguard Worker
566*da0073e9SAndroid Build Coastguard Worker        module = MyModule()
567*da0073e9SAndroid Build Coastguard Worker        tracer = TestTracer("mytracer")
568*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(module)
569*da0073e9SAndroid Build Coastguard Worker        node = list(traced.graph.nodes)[-2]
570*da0073e9SAndroid Build Coastguard Worker        p = torch.fx.Proxy(node, tracer)
571*da0073e9SAndroid Build Coastguard Worker        node.proxy = p
572*da0073e9SAndroid Build Coastguard Worker        p2 = copy.deepcopy(p)
573*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(p2, torch.fx.Proxy))
574*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(p2.tracer, torch.fx._symbolic_trace.Tracer))
575*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(p2.tracer.name, "mytracer")
576*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(p2.node.name, node.name)
577*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(p2.node.target, node.target)
578*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(id(p2.node), id(node))
579*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(id(p2.tracer), id(tracer))
580*da0073e9SAndroid Build Coastguard Worker
581*da0073e9SAndroid Build Coastguard Worker    def test_concrete_arg_none_assert(self):
582*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
583*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, val=None):
584*da0073e9SAndroid Build Coastguard Worker                return x if val is None else x + val
585*da0073e9SAndroid Build Coastguard Worker
586*da0073e9SAndroid Build Coastguard Worker        f = Foo()
587*da0073e9SAndroid Build Coastguard Worker        traced = torch.fx.symbolic_trace(f, concrete_args={'val' : None})
588*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(AssertionError, 'val has been specialized to have value None'):
589*da0073e9SAndroid Build Coastguard Worker            traced(torch.randn(5), torch.randn(5))
590*da0073e9SAndroid Build Coastguard Worker
591*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5)
592*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(traced(x), f(x))
593*da0073e9SAndroid Build Coastguard Worker
594*da0073e9SAndroid Build Coastguard Worker    def test_trace_multiple_funcs(self):
595*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
596*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y):
597*da0073e9SAndroid Build Coastguard Worker                return x + y
598*da0073e9SAndroid Build Coastguard Worker
599*da0073e9SAndroid Build Coastguard Worker            def minus_forward(self, x, y):
600*da0073e9SAndroid Build Coastguard Worker                return x - y
601*da0073e9SAndroid Build Coastguard Worker
602*da0073e9SAndroid Build Coastguard Worker            def multiply_forward(self, x, y):
603*da0073e9SAndroid Build Coastguard Worker                return x * y
604*da0073e9SAndroid Build Coastguard Worker
605*da0073e9SAndroid Build Coastguard Worker        f = Foo()
606*da0073e9SAndroid Build Coastguard Worker        x, y = torch.randn(5), torch.randn(5)
607*da0073e9SAndroid Build Coastguard Worker
608*da0073e9SAndroid Build Coastguard Worker        print(torch.__version__)
609*da0073e9SAndroid Build Coastguard Worker
610*da0073e9SAndroid Build Coastguard Worker        tracer = Tracer()
611*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(GraphModule(f, tracer.trace(f))(x, y), f(x, y))
612*da0073e9SAndroid Build Coastguard Worker
613*da0073e9SAndroid Build Coastguard Worker        tracer.traced_func_name = "minus_forward"
614*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(
615*da0073e9SAndroid Build Coastguard Worker            GraphModule(f, tracer.trace(f))(x, y),
616*da0073e9SAndroid Build Coastguard Worker            f.minus_forward(x, y),
617*da0073e9SAndroid Build Coastguard Worker        )
618*da0073e9SAndroid Build Coastguard Worker
619*da0073e9SAndroid Build Coastguard Worker        tracer.traced_func_name = "multiply_forward"
620*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(
621*da0073e9SAndroid Build Coastguard Worker            GraphModule(f, tracer.trace(f))(x, y),
622*da0073e9SAndroid Build Coastguard Worker            f.multiply_forward(x, y),
623*da0073e9SAndroid Build Coastguard Worker        )
624*da0073e9SAndroid Build Coastguard Worker
625*da0073e9SAndroid Build Coastguard Worker        tracer.traced_func_name = "add_forward"
626*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(AssertionError, "doesn't exist in"):
627*da0073e9SAndroid Build Coastguard Worker            tracer.trace(f)
628*da0073e9SAndroid Build Coastguard Worker
629*da0073e9SAndroid Build Coastguard Worker    def test_graph_unique_names(self):
630*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
631*da0073e9SAndroid Build Coastguard Worker            def forward(self, a, b):
632*da0073e9SAndroid Build Coastguard Worker                return a + b
633*da0073e9SAndroid Build Coastguard Worker        m = M()
634*da0073e9SAndroid Build Coastguard Worker        g = symbolic_trace(m).graph
635*da0073e9SAndroid Build Coastguard Worker        new_g = torch.fx.Graph()
636*da0073e9SAndroid Build Coastguard Worker        val_map : Dict[Node, Node] = {}
637*da0073e9SAndroid Build Coastguard Worker        output_val = new_g.graph_copy(g, val_map)
638*da0073e9SAndroid Build Coastguard Worker        t = Proxy(output_val)
639*da0073e9SAndroid Build Coastguard Worker        # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
640*da0073e9SAndroid Build Coastguard Worker        new_g.output((t + t).node)
641*da0073e9SAndroid Build Coastguard Worker        gm = GraphModule(m, new_g)
642*da0073e9SAndroid Build Coastguard Worker        seen_names : Set[str] = set()
643*da0073e9SAndroid Build Coastguard Worker        for node in gm.graph.nodes:
644*da0073e9SAndroid Build Coastguard Worker            assert node.name not in seen_names
645*da0073e9SAndroid Build Coastguard Worker            seen_names.add(node.name)
646*da0073e9SAndroid Build Coastguard Worker
647*da0073e9SAndroid Build Coastguard Worker    def test_stack_traces(self):
648*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
649*da0073e9SAndroid Build Coastguard Worker            def forward(self, a, b):
650*da0073e9SAndroid Build Coastguard Worker                return a + b
651*da0073e9SAndroid Build Coastguard Worker
652*da0073e9SAndroid Build Coastguard Worker        tracer = torch.fx.Tracer()
653*da0073e9SAndroid Build Coastguard Worker        tracer.record_stack_traces = True
654*da0073e9SAndroid Build Coastguard Worker
655*da0073e9SAndroid Build Coastguard Worker        graph = tracer.trace(M())
656*da0073e9SAndroid Build Coastguard Worker        # saving the original list because we will insert new nodes as a part of a test
657*da0073e9SAndroid Build Coastguard Worker        orig_graph_nodes = list(graph.nodes)
658*da0073e9SAndroid Build Coastguard Worker        for node in orig_graph_nodes:
659*da0073e9SAndroid Build Coastguard Worker            if node.op == 'output':
660*da0073e9SAndroid Build Coastguard Worker                continue
661*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(node.stack_trace is not None)
662*da0073e9SAndroid Build Coastguard Worker            assert 'test_fx.py' in node.stack_trace
663*da0073e9SAndroid Build Coastguard Worker
664*da0073e9SAndroid Build Coastguard Worker            # verify that copying the node does not lose the stack trace
665*da0073e9SAndroid Build Coastguard Worker            new_node = graph.node_copy(node)
666*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(new_node.stack_trace is not None)
667*da0073e9SAndroid Build Coastguard Worker            assert 'test_fx.py' in new_node.stack_trace
668*da0073e9SAndroid Build Coastguard Worker
669*da0073e9SAndroid Build Coastguard Worker    def test_stack_traces_with_transformer(self):
670*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
671*da0073e9SAndroid Build Coastguard Worker            def forward(self, a, b):
672*da0073e9SAndroid Build Coastguard Worker                return a + b
673*da0073e9SAndroid Build Coastguard Worker
674*da0073e9SAndroid Build Coastguard Worker        tracer = torch.fx.Tracer()
675*da0073e9SAndroid Build Coastguard Worker        tracer.record_stack_traces = True
676*da0073e9SAndroid Build Coastguard Worker
677*da0073e9SAndroid Build Coastguard Worker        graph = tracer.trace(M())
678*da0073e9SAndroid Build Coastguard Worker        gm = GraphModule(tracer.root, graph)
679*da0073e9SAndroid Build Coastguard Worker        new_gm = Transformer(gm).transform()
680*da0073e9SAndroid Build Coastguard Worker
681*da0073e9SAndroid Build Coastguard Worker        # nodes after Transformer should still preserve the original node's stack trace
682*da0073e9SAndroid Build Coastguard Worker        for node in new_gm.graph.nodes:
683*da0073e9SAndroid Build Coastguard Worker            if node.op in {'placeholder', 'output'}:
684*da0073e9SAndroid Build Coastguard Worker                continue
685*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(node.stack_trace is not None)
686*da0073e9SAndroid Build Coastguard Worker            assert 'test_fx.py' in node.stack_trace
687*da0073e9SAndroid Build Coastguard Worker
688*da0073e9SAndroid Build Coastguard Worker    def test_lineno_map(self):
689*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
690*da0073e9SAndroid Build Coastguard Worker            def forward(self, a, b):
691*da0073e9SAndroid Build Coastguard Worker                a = torch.sin(a)
692*da0073e9SAndroid Build Coastguard Worker                b = torch.cos(b)
693*da0073e9SAndroid Build Coastguard Worker                return a + b
694*da0073e9SAndroid Build Coastguard Worker
695*da0073e9SAndroid Build Coastguard Worker        tracer = torch.fx.Tracer()
696*da0073e9SAndroid Build Coastguard Worker        graph = tracer.trace(M())
697*da0073e9SAndroid Build Coastguard Worker        gm = GraphModule(tracer.root, graph)
698*da0073e9SAndroid Build Coastguard Worker        expected = {1: 2, 2: 3, 3: 4, 4: 5}
699*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
700*da0073e9SAndroid Build Coastguard Worker
701*da0073e9SAndroid Build Coastguard Worker        # test custom codegen
702*da0073e9SAndroid Build Coastguard Worker        def transform_code(code):
703*da0073e9SAndroid Build Coastguard Worker            return ["print('hello!')\n", *code]
704*da0073e9SAndroid Build Coastguard Worker        gm.graph.on_generate_code(lambda _: transform_code)
705*da0073e9SAndroid Build Coastguard Worker        gm.recompile()
706*da0073e9SAndroid Build Coastguard Worker        expected = {2: 2, 3: 3, 4: 4, 5: 5}
707*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
708*da0073e9SAndroid Build Coastguard Worker
709*da0073e9SAndroid Build Coastguard Worker    def test_graph_unique_names_manual(self):
710*da0073e9SAndroid Build Coastguard Worker        graph : torch.fx.Graph = torch.fx.Graph()
711*da0073e9SAndroid Build Coastguard Worker        a : torch.fx.Node = graph.create_node('placeholder', 'x')
712*da0073e9SAndroid Build Coastguard Worker        b : torch.fx.Node = graph.create_node('call_module', 'linear_mod', args=(a,), name='foo_1_1')
713*da0073e9SAndroid Build Coastguard Worker        c : torch.fx.Node = graph.create_node('get_attr', 'y_attr', name='foo_1')
714*da0073e9SAndroid Build Coastguard Worker        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
715*da0073e9SAndroid Build Coastguard Worker        graph.output(d)
716*da0073e9SAndroid Build Coastguard Worker        graph2 = torch.fx.Graph()
717*da0073e9SAndroid Build Coastguard Worker        val_map : Dict[Node, Node] = {}
718*da0073e9SAndroid Build Coastguard Worker        graph2.graph_copy(graph, val_map)
719*da0073e9SAndroid Build Coastguard Worker        seen_names : Set[str] = set()
720*da0073e9SAndroid Build Coastguard Worker        for node in graph2.nodes:
721*da0073e9SAndroid Build Coastguard Worker            assert node.name not in seen_names
722*da0073e9SAndroid Build Coastguard Worker            seen_names.add(node.name)
723*da0073e9SAndroid Build Coastguard Worker
724*da0073e9SAndroid Build Coastguard Worker    def test_unpack(self):
725*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
726*da0073e9SAndroid Build Coastguard Worker            def forward(self, a, b):
727*da0073e9SAndroid Build Coastguard Worker                c, d = a
728*da0073e9SAndroid Build Coastguard Worker                return c + d + b
729*da0073e9SAndroid Build Coastguard Worker
730*da0073e9SAndroid Build Coastguard Worker        a = (torch.rand(1), torch.rand(1))
731*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(1)
732*da0073e9SAndroid Build Coastguard Worker        m = M()
733*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(m, (a, b))
734*da0073e9SAndroid Build Coastguard Worker
735*da0073e9SAndroid Build Coastguard Worker    def test_native_callable(self):
736*da0073e9SAndroid Build Coastguard Worker        if IS_FBCODE or IS_WINDOWS or IS_MACOS:
737*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest("non-portable load_library call used in test")
738*da0073e9SAndroid Build Coastguard Worker        # This test exercises the case where we use FX to translate from Python
739*da0073e9SAndroid Build Coastguard Worker        # code to some native callable object
740*da0073e9SAndroid Build Coastguard Worker        #
741*da0073e9SAndroid Build Coastguard Worker        # For the purposes of testing, we use ElementwiseInterpreter defined
742*da0073e9SAndroid Build Coastguard Worker        # in test_custom_class.cpp.
743*da0073e9SAndroid Build Coastguard Worker        #
744*da0073e9SAndroid Build Coastguard Worker        # We test that we can
745*da0073e9SAndroid Build Coastguard Worker        # 1) Construct a native callable from FX IR
746*da0073e9SAndroid Build Coastguard Worker        # 2) Construct a drop-in replacement module that delegates to the
747*da0073e9SAndroid Build Coastguard Worker        #    native callable rather than the original code
748*da0073e9SAndroid Build Coastguard Worker        # 3) Run both the original code and native callable wrapper with
749*da0073e9SAndroid Build Coastguard Worker        #    equivalent results
750*da0073e9SAndroid Build Coastguard Worker        # 4) TorchScript compile the native callable wrapper and confirm
751*da0073e9SAndroid Build Coastguard Worker        #    equivalent results with the reference
752*da0073e9SAndroid Build Coastguard Worker        # 5) TorchScript serialize and deserialize the native callable
753*da0073e9SAndroid Build Coastguard Worker        #    and confirm equivalent results with the reference
754*da0073e9SAndroid Build Coastguard Worker
755*da0073e9SAndroid Build Coastguard Worker        # We use this simple Module as a reference computation
756*da0073e9SAndroid Build Coastguard Worker        class MySimpleMod(torch.nn.Module):
757*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
758*da0073e9SAndroid Build Coastguard Worker                return 3.0 * x + x
759*da0073e9SAndroid Build Coastguard Worker
760*da0073e9SAndroid Build Coastguard Worker        msm = MySimpleMod()
761*da0073e9SAndroid Build Coastguard Worker
762*da0073e9SAndroid Build Coastguard Worker        # This is what a lowering pass might look like: a function that takes
763*da0073e9SAndroid Build Coastguard Worker        # a valid nn.Module, symbolically traces it, lowers the Module to some
764*da0073e9SAndroid Build Coastguard Worker        # representation, and wraps that representation up into another
765*da0073e9SAndroid Build Coastguard Worker        # nn.Module instance that handles dispatch to the compiled/lowered code.
766*da0073e9SAndroid Build Coastguard Worker        def lower_to_elementwise_interpreter(orig_mod : torch.nn.Module) -> torch.nn.Module:
767*da0073e9SAndroid Build Coastguard Worker            # ===== Stage 1: Symbolic trace the module =====
768*da0073e9SAndroid Build Coastguard Worker            mod = symbolic_trace(orig_mod)
769*da0073e9SAndroid Build Coastguard Worker
770*da0073e9SAndroid Build Coastguard Worker            # ===== Stage 2: Lower GraphModule representation to the C++
771*da0073e9SAndroid Build Coastguard Worker            #       interpreter's instruction format ======
772*da0073e9SAndroid Build Coastguard Worker            instructions = []
773*da0073e9SAndroid Build Coastguard Worker            constant_idx = 0
774*da0073e9SAndroid Build Coastguard Worker            constants = {}
775*da0073e9SAndroid Build Coastguard Worker            fn_input_names = []
776*da0073e9SAndroid Build Coastguard Worker
777*da0073e9SAndroid Build Coastguard Worker            target_to_name = {
778*da0073e9SAndroid Build Coastguard Worker                operator.add : "add",
779*da0073e9SAndroid Build Coastguard Worker                operator.mul : "mul"
780*da0073e9SAndroid Build Coastguard Worker            }
781*da0073e9SAndroid Build Coastguard Worker
782*da0073e9SAndroid Build Coastguard Worker            output_node : Optional[Node] = None
783*da0073e9SAndroid Build Coastguard Worker            # For each instruction, create a triple
784*da0073e9SAndroid Build Coastguard Worker            # (instruction_name : str, inputs : List[str], output : str)
785*da0073e9SAndroid Build Coastguard Worker            # to feed into the C++ interpreter
786*da0073e9SAndroid Build Coastguard Worker            for n in mod.graph.nodes:
787*da0073e9SAndroid Build Coastguard Worker                target, args, out_name = n.target, n.args, n.name
788*da0073e9SAndroid Build Coastguard Worker                assert len(n.kwargs) == 0, "kwargs currently not supported"
789*da0073e9SAndroid Build Coastguard Worker
790*da0073e9SAndroid Build Coastguard Worker                if n.op == 'placeholder':
791*da0073e9SAndroid Build Coastguard Worker                    # Placeholders specify function argument names. Save these
792*da0073e9SAndroid Build Coastguard Worker                    # for later when we generate the wrapper GraphModule
793*da0073e9SAndroid Build Coastguard Worker                    fn_input_names.append(target)
794*da0073e9SAndroid Build Coastguard Worker                elif n.op == 'call_function':
795*da0073e9SAndroid Build Coastguard Worker                    assert target in target_to_name, "Unsupported call target " + target
796*da0073e9SAndroid Build Coastguard Worker                    arg_names = []
797*da0073e9SAndroid Build Coastguard Worker                    for arg in args:
798*da0073e9SAndroid Build Coastguard Worker                        if not isinstance(arg, Node):
799*da0073e9SAndroid Build Coastguard Worker                            # Pull out constants. These constants will later be
800*da0073e9SAndroid Build Coastguard Worker                            # fed to the interpreter C++ object via add_constant()
801*da0073e9SAndroid Build Coastguard Worker                            arg_name = f'constant_{constant_idx}'
802*da0073e9SAndroid Build Coastguard Worker                            constants[arg_name] = torch.tensor(
803*da0073e9SAndroid Build Coastguard Worker                                [arg] if isinstance(arg, numbers.Number) else arg)
804*da0073e9SAndroid Build Coastguard Worker                            arg_names.append(arg_name)
805*da0073e9SAndroid Build Coastguard Worker                            constant_idx += 1
806*da0073e9SAndroid Build Coastguard Worker                        else:
807*da0073e9SAndroid Build Coastguard Worker                            arg_names.append(arg.name)
808*da0073e9SAndroid Build Coastguard Worker                    instructions.append((target_to_name[target], arg_names, out_name))
809*da0073e9SAndroid Build Coastguard Worker                elif n.op == 'output':
810*da0073e9SAndroid Build Coastguard Worker                    if output_node is not None:
811*da0073e9SAndroid Build Coastguard Worker                        raise RuntimeError('Multiple output nodes!')
812*da0073e9SAndroid Build Coastguard Worker                    output_node = n
813*da0073e9SAndroid Build Coastguard Worker                else:
814*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError('Unsupported opcode ' + n.op)
815*da0073e9SAndroid Build Coastguard Worker
816*da0073e9SAndroid Build Coastguard Worker            interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter()
817*da0073e9SAndroid Build Coastguard Worker            # Load constants
818*da0073e9SAndroid Build Coastguard Worker            for k, v in constants.items():
819*da0073e9SAndroid Build Coastguard Worker                interpreter.add_constant(k, v)
820*da0073e9SAndroid Build Coastguard Worker            # Specify names for positional input arguments
821*da0073e9SAndroid Build Coastguard Worker            interpreter.set_input_names(fn_input_names)
822*da0073e9SAndroid Build Coastguard Worker            # Load instructions
823*da0073e9SAndroid Build Coastguard Worker            interpreter.set_instructions(instructions)
824*da0073e9SAndroid Build Coastguard Worker            # Specify name for single output
825*da0073e9SAndroid Build Coastguard Worker            assert isinstance(output_node.args[0], torch.fx.Node)
826*da0073e9SAndroid Build Coastguard Worker            interpreter.set_output_name(output_node.args[0].name)
827*da0073e9SAndroid Build Coastguard Worker
828*da0073e9SAndroid Build Coastguard Worker            # ===== Stage 3: Create a wrapper GraphModule around the interpreter =====
829*da0073e9SAndroid Build Coastguard Worker            class WrapperModule(torch.nn.Module):
830*da0073e9SAndroid Build Coastguard Worker                def __init__(self, interpreter):
831*da0073e9SAndroid Build Coastguard Worker                    super().__init__()
832*da0073e9SAndroid Build Coastguard Worker                    self.interpreter = interpreter
833*da0073e9SAndroid Build Coastguard Worker
834*da0073e9SAndroid Build Coastguard Worker            wrapper = WrapperModule(interpreter)
835*da0073e9SAndroid Build Coastguard Worker
836*da0073e9SAndroid Build Coastguard Worker            # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter
837*da0073e9SAndroid Build Coastguard Worker            # 3) Returns the speficied return value
838*da0073e9SAndroid Build Coastguard Worker
839*da0073e9SAndroid Build Coastguard Worker            # FIXME: The following code could be greatly simplified by symbolic_trace'ing
840*da0073e9SAndroid Build Coastguard Worker            # the wrapper with a Tracer that considers the Wrapper instance a root
841*da0073e9SAndroid Build Coastguard Worker            # module, however, I can't get `__call__` exposed on TorchBind classes
842*da0073e9SAndroid Build Coastguard Worker            # without it messing up Python `hasattr` for some reason. More digging
843*da0073e9SAndroid Build Coastguard Worker            # into CPython's implementation of hasattr is probably in order...
844*da0073e9SAndroid Build Coastguard Worker
845*da0073e9SAndroid Build Coastguard Worker            graph = torch.fx.Graph()
846*da0073e9SAndroid Build Coastguard Worker            # Add placeholders for fn inputs
847*da0073e9SAndroid Build Coastguard Worker            placeholder_nodes = []
848*da0073e9SAndroid Build Coastguard Worker            for name in fn_input_names:
849*da0073e9SAndroid Build Coastguard Worker                placeholder_nodes.append(graph.create_node('placeholder', name))
850*da0073e9SAndroid Build Coastguard Worker
851*da0073e9SAndroid Build Coastguard Worker            # Get the interpreter object
852*da0073e9SAndroid Build Coastguard Worker            interpreter_node = graph.create_node('get_attr', 'interpreter')
853*da0073e9SAndroid Build Coastguard Worker
854*da0073e9SAndroid Build Coastguard Worker            # Add a node to call the interpreter instance
855*da0073e9SAndroid Build Coastguard Worker            output_node = graph.create_node(
856*da0073e9SAndroid Build Coastguard Worker                op='call_method', target='__call__', args=(interpreter_node, placeholder_nodes))
857*da0073e9SAndroid Build Coastguard Worker
858*da0073e9SAndroid Build Coastguard Worker            # Register output
859*da0073e9SAndroid Build Coastguard Worker            graph.output(output_node)
860*da0073e9SAndroid Build Coastguard Worker
861*da0073e9SAndroid Build Coastguard Worker            graph.lint()
862*da0073e9SAndroid Build Coastguard Worker
863*da0073e9SAndroid Build Coastguard Worker            # Return final GraphModule!!!
864*da0073e9SAndroid Build Coastguard Worker            return GraphModule(wrapper, graph)
865*da0073e9SAndroid Build Coastguard Worker
866*da0073e9SAndroid Build Coastguard Worker        # Lower GraphModule to C++ interpreter
867*da0073e9SAndroid Build Coastguard Worker        lowered = lower_to_elementwise_interpreter(msm)
868*da0073e9SAndroid Build Coastguard Worker
869*da0073e9SAndroid Build Coastguard Worker        # Compare correctness with original module
870*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
871*da0073e9SAndroid Build Coastguard Worker        ref_out = msm(x)
872*da0073e9SAndroid Build Coastguard Worker        test_out = lowered(x)
873*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(test_out, ref_out)
874*da0073e9SAndroid Build Coastguard Worker
875*da0073e9SAndroid Build Coastguard Worker        # Test TorchScript compilation
876*da0073e9SAndroid Build Coastguard Worker        scripted_lowered = torch.jit.script(lowered)
877*da0073e9SAndroid Build Coastguard Worker        script_out = scripted_lowered(x)
878*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(script_out, ref_out)
879*da0073e9SAndroid Build Coastguard Worker
880*da0073e9SAndroid Build Coastguard Worker        # Test TorchScript ser/de
881*da0073e9SAndroid Build Coastguard Worker        import_copy = self.getExportImportCopy(scripted_lowered)
882*da0073e9SAndroid Build Coastguard Worker        imported_out = import_copy(x)
883*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(imported_out, ref_out)
884*da0073e9SAndroid Build Coastguard Worker
885*da0073e9SAndroid Build Coastguard Worker    def test_reserved_getattr(self):
886*da0073e9SAndroid Build Coastguard Worker        """Ensure that we do not name any nodes with a reserved builtin like `getattr`"""
887*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
888*da0073e9SAndroid Build Coastguard Worker            def forward(self, a):
889*da0073e9SAndroid Build Coastguard Worker                return a.foo.bar.baz
890*da0073e9SAndroid Build Coastguard Worker
891*da0073e9SAndroid Build Coastguard Worker        m = M()
892*da0073e9SAndroid Build Coastguard Worker        m_g = symbolic_trace(m)
893*da0073e9SAndroid Build Coastguard Worker        m_g.graph.lint()
894*da0073e9SAndroid Build Coastguard Worker        for node in m_g.graph.nodes:
895*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(node.name != "getattr")
896*da0073e9SAndroid Build Coastguard Worker
897*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("Hotfix for SEV remediation")
898*da0073e9SAndroid Build Coastguard Worker    def test_trace_buffer_slice(self):
899*da0073e9SAndroid Build Coastguard Worker        bs, d_hid = 10, 23
900*da0073e9SAndroid Build Coastguard Worker
901*da0073e9SAndroid Build Coastguard Worker        class ExampleCode(torch.nn.Module):
902*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
903*da0073e9SAndroid Build Coastguard Worker                super().__init__()
904*da0073e9SAndroid Build Coastguard Worker                self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid))
905*da0073e9SAndroid Build Coastguard Worker                self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
906*da0073e9SAndroid Build Coastguard Worker                self.lin = torch.nn.Linear(d_hid, d_hid)
907*da0073e9SAndroid Build Coastguard Worker                self.buffer = torch.nn.Buffer(torch.randn(bs + 100, d_hid))
908*da0073e9SAndroid Build Coastguard Worker
909*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
910*da0073e9SAndroid Build Coastguard Worker                x = torch.mm(x, self.mm_param)
911*da0073e9SAndroid Build Coastguard Worker                skip_connection = x
912*da0073e9SAndroid Build Coastguard Worker                x = torch.relu(x)
913*da0073e9SAndroid Build Coastguard Worker                x = torch.mm(x, self.mm_param) + self.buffer[:x.shape[0]]
914*da0073e9SAndroid Build Coastguard Worker                x = self.lin(x)
915*da0073e9SAndroid Build Coastguard Worker                x = torch.relu(x)
916*da0073e9SAndroid Build Coastguard Worker                x = x + skip_connection
917*da0073e9SAndroid Build Coastguard Worker                x = torch.mm(x, self.mm_param2)
918*da0073e9SAndroid Build Coastguard Worker                x = self.lin(x)
919*da0073e9SAndroid Build Coastguard Worker                return x
920*da0073e9SAndroid Build Coastguard Worker
921*da0073e9SAndroid Build Coastguard Worker        ec = ExampleCode()
922*da0073e9SAndroid Build Coastguard Worker
923*da0073e9SAndroid Build Coastguard Worker        traced = torch.fx.symbolic_trace(ec)
924*da0073e9SAndroid Build Coastguard Worker
925*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(bs, d_hid)
926*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(ec(x), traced(x))
927*da0073e9SAndroid Build Coastguard Worker
928*da0073e9SAndroid Build Coastguard Worker    def test_node_tagging(self):
929*da0073e9SAndroid Build Coastguard Worker        class TaggingTracer(Tracer):
930*da0073e9SAndroid Build Coastguard Worker            def create_node(self, kind : str, target : Union[str, Callable],
931*da0073e9SAndroid Build Coastguard Worker                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
932*da0073e9SAndroid Build Coastguard Worker                            type_expr : Optional[Any] = None) -> Node:
933*da0073e9SAndroid Build Coastguard Worker                n = super().create_node(kind, target, args, kwargs, name)
934*da0073e9SAndroid Build Coastguard Worker                n.tag = 'foo'
935*da0073e9SAndroid Build Coastguard Worker                return n
936*da0073e9SAndroid Build Coastguard Worker
937*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
938*da0073e9SAndroid Build Coastguard Worker            def forward(self, a, b):
939*da0073e9SAndroid Build Coastguard Worker                return a + b
940*da0073e9SAndroid Build Coastguard Worker
941*da0073e9SAndroid Build Coastguard Worker        m = M()
942*da0073e9SAndroid Build Coastguard Worker        g = TaggingTracer().trace(m)
943*da0073e9SAndroid Build Coastguard Worker        g.lint()
944*da0073e9SAndroid Build Coastguard Worker        for n in g.nodes:
945*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(hasattr(n, 'tag'))
946*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(n.tag, 'foo')
947*da0073e9SAndroid Build Coastguard Worker
948*da0073e9SAndroid Build Coastguard Worker    def test_tensor_attribute(self):
949*da0073e9SAndroid Build Coastguard Worker        class TensorAttribute(torch.nn.Module):
950*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
951*da0073e9SAndroid Build Coastguard Worker                super().__init__()
952*da0073e9SAndroid Build Coastguard Worker                self.tensor = torch.rand(3, 4)
953*da0073e9SAndroid Build Coastguard Worker
954*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
955*da0073e9SAndroid Build Coastguard Worker                return torch.nn.functional.linear(x, self.tensor)
956*da0073e9SAndroid Build Coastguard Worker
957*da0073e9SAndroid Build Coastguard Worker        ta = TensorAttribute()
958*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(ta)
959*da0073e9SAndroid Build Coastguard Worker        traced(torch.rand(4, 4))
960*da0073e9SAndroid Build Coastguard Worker
961*da0073e9SAndroid Build Coastguard Worker        class WrapperForQualname(torch.nn.Module):
962*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
963*da0073e9SAndroid Build Coastguard Worker                super().__init__()
964*da0073e9SAndroid Build Coastguard Worker                self.ta = TensorAttribute()
965*da0073e9SAndroid Build Coastguard Worker
966*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
967*da0073e9SAndroid Build Coastguard Worker                return torch.nn.functional.linear(x, self.ta.tensor)
968*da0073e9SAndroid Build Coastguard Worker
969*da0073e9SAndroid Build Coastguard Worker        wfq = WrapperForQualname()
970*da0073e9SAndroid Build Coastguard Worker        traced2 = symbolic_trace(wfq)
971*da0073e9SAndroid Build Coastguard Worker        traced2.graph.lint()
972*da0073e9SAndroid Build Coastguard Worker        traced2(torch.rand(4, 4))
973*da0073e9SAndroid Build Coastguard Worker
974*da0073e9SAndroid Build Coastguard Worker    def test_tensor_attribute_coalseced(self):
975*da0073e9SAndroid Build Coastguard Worker
976*da0073e9SAndroid Build Coastguard Worker        def count_attrs(fx_module):
977*da0073e9SAndroid Build Coastguard Worker            targets = set()
978*da0073e9SAndroid Build Coastguard Worker            for node in traced.graph.nodes:
979*da0073e9SAndroid Build Coastguard Worker                if node.op == 'get_attr':
980*da0073e9SAndroid Build Coastguard Worker                    targets.add(node.target)
981*da0073e9SAndroid Build Coastguard Worker            return len(targets)
982*da0073e9SAndroid Build Coastguard Worker
983*da0073e9SAndroid Build Coastguard Worker        val = torch.tensor(5)
984*da0073e9SAndroid Build Coastguard Worker
985*da0073e9SAndroid Build Coastguard Worker        def f(x):
986*da0073e9SAndroid Build Coastguard Worker            return x + val + val
987*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(f)
988*da0073e9SAndroid Build Coastguard Worker        traced.graph.lint()
989*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count_attrs(traced), 1)
990*da0073e9SAndroid Build Coastguard Worker
991*da0073e9SAndroid Build Coastguard Worker        val2 = torch.tensor(5)
992*da0073e9SAndroid Build Coastguard Worker
993*da0073e9SAndroid Build Coastguard Worker        def f(x):
994*da0073e9SAndroid Build Coastguard Worker            val = torch.tensor(5)
995*da0073e9SAndroid Build Coastguard Worker            return x + val + val2
996*da0073e9SAndroid Build Coastguard Worker
997*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(f)
998*da0073e9SAndroid Build Coastguard Worker        traced.graph.lint()
999*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count_attrs(traced), 2)
1000*da0073e9SAndroid Build Coastguard Worker
1001*da0073e9SAndroid Build Coastguard Worker    def test_symbolic_trace_sequential(self):
1002*da0073e9SAndroid Build Coastguard Worker        class Simple(torch.nn.Module):
1003*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1004*da0073e9SAndroid Build Coastguard Worker                return torch.neg(x)
1005*da0073e9SAndroid Build Coastguard Worker
1006*da0073e9SAndroid Build Coastguard Worker        seq = torch.nn.Sequential(
1007*da0073e9SAndroid Build Coastguard Worker            Simple(),
1008*da0073e9SAndroid Build Coastguard Worker            Simple(),
1009*da0073e9SAndroid Build Coastguard Worker            Simple()
1010*da0073e9SAndroid Build Coastguard Worker        )
1011*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(seq)
1012*da0073e9SAndroid Build Coastguard Worker        traced.graph.lint()
1013*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
1014*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced(x), seq(x))
1015*da0073e9SAndroid Build Coastguard Worker
1016*da0073e9SAndroid Build Coastguard Worker    def test_tensor_constant(self):
1017*da0073e9SAndroid Build Coastguard Worker        class ConstTensor(torch.nn.Module):
1018*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1019*da0073e9SAndroid Build Coastguard Worker                return torch.nn.functional.linear(x, torch.zeros(3, 4))
1020*da0073e9SAndroid Build Coastguard Worker
1021*da0073e9SAndroid Build Coastguard Worker        ct = ConstTensor()
1022*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(ct)
1023*da0073e9SAndroid Build Coastguard Worker        traced.graph.lint()
1024*da0073e9SAndroid Build Coastguard Worker        traced(torch.rand(4, 4))
1025*da0073e9SAndroid Build Coastguard Worker
1026*da0073e9SAndroid Build Coastguard Worker    def test_pickle_graphmodule(self):
1027*da0073e9SAndroid Build Coastguard Worker        class Nested(torch.nn.Module):
1028*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1029*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1030*da0073e9SAndroid Build Coastguard Worker                self.st = torch.nn.Linear(4, 4)
1031*da0073e9SAndroid Build Coastguard Worker
1032*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1033*da0073e9SAndroid Build Coastguard Worker                return self.st(x)
1034*da0073e9SAndroid Build Coastguard Worker
1035*da0073e9SAndroid Build Coastguard Worker        n = Nested()
1036*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(n)
1037*da0073e9SAndroid Build Coastguard Worker        traced.graph.lint()
1038*da0073e9SAndroid Build Coastguard Worker        pickled = pickle.dumps(traced)
1039*da0073e9SAndroid Build Coastguard Worker        loaded = pickle.loads(pickled)
1040*da0073e9SAndroid Build Coastguard Worker        loaded.graph.lint()
1041*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
1042*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(loaded(x), traced(x))
1043*da0073e9SAndroid Build Coastguard Worker
1044*da0073e9SAndroid Build Coastguard Worker    def test_pickle_custom_import(self):
1045*da0073e9SAndroid Build Coastguard Worker        graph = torch.fx.Graph()
1046*da0073e9SAndroid Build Coastguard Worker        a = graph.placeholder('x')
1047*da0073e9SAndroid Build Coastguard Worker        b = graph.placeholder('y')
1048*da0073e9SAndroid Build Coastguard Worker        c = graph.call_function(a_non_torch_leaf, (a, b))
1049*da0073e9SAndroid Build Coastguard Worker        d = graph.call_function(torch.sin, (c,))
1050*da0073e9SAndroid Build Coastguard Worker        graph.output(d)
1051*da0073e9SAndroid Build Coastguard Worker        gm = GraphModule(torch.nn.Module(), graph)
1052*da0073e9SAndroid Build Coastguard Worker        pickled = pickle.dumps(gm)
1053*da0073e9SAndroid Build Coastguard Worker        loaded = pickle.loads(pickled)
1054*da0073e9SAndroid Build Coastguard Worker        loaded.graph.lint()
1055*da0073e9SAndroid Build Coastguard Worker        x, y = torch.rand(1), torch.rand(1)
1056*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(loaded(x, y), gm(x, y))
1057*da0073e9SAndroid Build Coastguard Worker
1058*da0073e9SAndroid Build Coastguard Worker    def test_all_input_nodes(self):
1059*da0073e9SAndroid Build Coastguard Worker        graph : torch.fx.Graph = torch.fx.Graph()
1060*da0073e9SAndroid Build Coastguard Worker        a : torch.fx.Node = graph.placeholder('x')
1061*da0073e9SAndroid Build Coastguard Worker        b : torch.fx.Node = graph.call_module('linear_mod', args=(a,))
1062*da0073e9SAndroid Build Coastguard Worker        c : torch.fx.Node = graph.get_attr('y_attr')
1063*da0073e9SAndroid Build Coastguard Worker        d : torch.fx.Node = graph.call_function(operator.add, args=(b, c))
1064*da0073e9SAndroid Build Coastguard Worker        e : torch.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0))
1065*da0073e9SAndroid Build Coastguard Worker        graph.output(e)
1066*da0073e9SAndroid Build Coastguard Worker        graph.lint()
1067*da0073e9SAndroid Build Coastguard Worker
1068*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.all_input_nodes, [a])
1069*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c.all_input_nodes, [])
1070*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(d.all_input_nodes, [b, c])
1071*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(e.all_input_nodes, [d])
1072*da0073e9SAndroid Build Coastguard Worker
1073*da0073e9SAndroid Build Coastguard Worker    def test_deepcopy_graphmodule_with_transform(self):
1074*da0073e9SAndroid Build Coastguard Worker        st = SimpleTest()
1075*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(st)
1076*da0073e9SAndroid Build Coastguard Worker        traced.graph.lint()
1077*da0073e9SAndroid Build Coastguard Worker
1078*da0073e9SAndroid Build Coastguard Worker        def transform(traced):
1079*da0073e9SAndroid Build Coastguard Worker            new_graph = torch.fx.Graph()
1080*da0073e9SAndroid Build Coastguard Worker            val_map : Dict[Node, Node] = {}
1081*da0073e9SAndroid Build Coastguard Worker            output_value = new_graph.graph_copy(traced.graph, val_map)
1082*da0073e9SAndroid Build Coastguard Worker            relu_out = new_graph.create_node(
1083*da0073e9SAndroid Build Coastguard Worker                op='call_method', target='neg', args=(output_value,), kwargs={})
1084*da0073e9SAndroid Build Coastguard Worker            new_graph.output(relu_out)
1085*da0073e9SAndroid Build Coastguard Worker            return GraphModule(traced, new_graph)
1086*da0073e9SAndroid Build Coastguard Worker        transformed = transform(traced)
1087*da0073e9SAndroid Build Coastguard Worker        transformed.graph.lint()
1088*da0073e9SAndroid Build Coastguard Worker        copied = copy.deepcopy(transformed)
1089*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(id(type(transformed)), id(type(copied)))
1090*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 4)
1091*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(copied(x), transformed(x))
1092*da0073e9SAndroid Build Coastguard Worker
1093*da0073e9SAndroid Build Coastguard Worker    def test_deepcopy_with_submods_params(self):
1094*da0073e9SAndroid Build Coastguard Worker        class Bar(torch.nn.Module):
1095*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1096*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1097*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(3, 4))
1098*da0073e9SAndroid Build Coastguard Worker
1099*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1100*da0073e9SAndroid Build Coastguard Worker                return torch.relu(x) + self.param
1101*da0073e9SAndroid Build Coastguard Worker
1102*da0073e9SAndroid Build Coastguard Worker        class Baz(torch.nn.Module):
1103*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1104*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1105*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(3, 4))
1106*da0073e9SAndroid Build Coastguard Worker                self.bar = Bar()
1107*da0073e9SAndroid Build Coastguard Worker
1108*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1109*da0073e9SAndroid Build Coastguard Worker                return self.bar(x) - self.param
1110*da0073e9SAndroid Build Coastguard Worker
1111*da0073e9SAndroid Build Coastguard Worker        baz = Baz()
1112*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(baz)
1113*da0073e9SAndroid Build Coastguard Worker        traced.graph.lint()
1114*da0073e9SAndroid Build Coastguard Worker        copied = copy.deepcopy(traced)
1115*da0073e9SAndroid Build Coastguard Worker        copied.graph.lint()
1116*da0073e9SAndroid Build Coastguard Worker
1117*da0073e9SAndroid Build Coastguard Worker    def test_deepcopy_graph_with_tracer_cls(self):
1118*da0073e9SAndroid Build Coastguard Worker        class TestTracer(Tracer):
1119*da0073e9SAndroid Build Coastguard Worker            def is_leaf_module(self, module, name):
1120*da0073e9SAndroid Build Coastguard Worker                return True
1121*da0073e9SAndroid Build Coastguard Worker
1122*da0073e9SAndroid Build Coastguard Worker        g = Graph(tracer_cls=TestTracer)
1123*da0073e9SAndroid Build Coastguard Worker        x = g.placeholder("x")
1124*da0073e9SAndroid Build Coastguard Worker        g.output(x)
1125*da0073e9SAndroid Build Coastguard Worker
1126*da0073e9SAndroid Build Coastguard Worker        h = copy.deepcopy(g)
1127*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(h._tracer_cls)
1128*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(g._tracer_cls == h._tracer_cls)
1129*da0073e9SAndroid Build Coastguard Worker
1130*da0073e9SAndroid Build Coastguard Worker    def test_unpack_list_better_error(self):
1131*da0073e9SAndroid Build Coastguard Worker        class SomeArgs(torch.nn.Module):
1132*da0073e9SAndroid Build Coastguard Worker            def forward(self, a, b):
1133*da0073e9SAndroid Build Coastguard Worker                return torch.rand(3, 4)
1134*da0073e9SAndroid Build Coastguard Worker
1135*da0073e9SAndroid Build Coastguard Worker        class UnpacksList(torch.nn.Module):
1136*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1137*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1138*da0073e9SAndroid Build Coastguard Worker                self.sa = SomeArgs()
1139*da0073e9SAndroid Build Coastguard Worker
1140*da0073e9SAndroid Build Coastguard Worker            def forward(self, x : list):
1141*da0073e9SAndroid Build Coastguard Worker                return self.sa(*x)
1142*da0073e9SAndroid Build Coastguard Worker
1143*da0073e9SAndroid Build Coastguard Worker        ul = UnpacksList()
1144*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
1145*da0073e9SAndroid Build Coastguard Worker            symbolic_trace(ul)
1146*da0073e9SAndroid Build Coastguard Worker
1147*da0073e9SAndroid Build Coastguard Worker    def test_unpack_dict_better_error(self):
1148*da0073e9SAndroid Build Coastguard Worker        class SomeKwargs(torch.nn.Module):
1149*da0073e9SAndroid Build Coastguard Worker            def forward(self, x=3, y=4):
1150*da0073e9SAndroid Build Coastguard Worker                return torch.rand(3, 4)
1151*da0073e9SAndroid Build Coastguard Worker
1152*da0073e9SAndroid Build Coastguard Worker        class UnpacksDict(torch.nn.Module):
1153*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1154*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1155*da0073e9SAndroid Build Coastguard Worker                self.sk = SomeKwargs()
1156*da0073e9SAndroid Build Coastguard Worker
1157*da0073e9SAndroid Build Coastguard Worker            def forward(self, x : dict):
1158*da0073e9SAndroid Build Coastguard Worker                return self.sk(**x)
1159*da0073e9SAndroid Build Coastguard Worker
1160*da0073e9SAndroid Build Coastguard Worker        ud = UnpacksDict()
1161*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
1162*da0073e9SAndroid Build Coastguard Worker            symbolic_trace(ud)
1163*da0073e9SAndroid Build Coastguard Worker
1164*da0073e9SAndroid Build Coastguard Worker    def test_pretty_print_targets(self):
1165*da0073e9SAndroid Build Coastguard Worker        # Test that Graph pretty-print prints friendly name for targets
1166*da0073e9SAndroid Build Coastguard Worker        # in `operator` and `builtins`
1167*da0073e9SAndroid Build Coastguard Worker
1168*da0073e9SAndroid Build Coastguard Worker        class SomeMod(torch.nn.Module):
1169*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1170*da0073e9SAndroid Build Coastguard Worker                return torch.add(x.foo + x.bar, 3.0)
1171*da0073e9SAndroid Build Coastguard Worker
1172*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(SomeMod())
1173*da0073e9SAndroid Build Coastguard Worker        graph_str = str(traced.graph)
1174*da0073e9SAndroid Build Coastguard Worker        self.assertIn('builtins.getattr', graph_str)
1175*da0073e9SAndroid Build Coastguard Worker        self.assertIn('operator.add', graph_str)
1176*da0073e9SAndroid Build Coastguard Worker        self.assertIn('torch.add', graph_str)
1177*da0073e9SAndroid Build Coastguard Worker
1178*da0073e9SAndroid Build Coastguard Worker    def test_pretty_print_node(self):
1179*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
1180*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1181*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1182*da0073e9SAndroid Build Coastguard Worker                self.param: torch.nn.Parameter = torch.nn.Parameter(
1183*da0073e9SAndroid Build Coastguard Worker                    torch.rand(3, 4))
1184*da0073e9SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(4, 5)
1185*da0073e9SAndroid Build Coastguard Worker
1186*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor, y: int = 2):
1187*da0073e9SAndroid Build Coastguard Worker                return self.linear(x[y] + self.param).clamp(min=0.0, max=1.0)
1188*da0073e9SAndroid Build Coastguard Worker
1189*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(M())
1190*da0073e9SAndroid Build Coastguard Worker
1191*da0073e9SAndroid Build Coastguard Worker        all_formatted = "\n".join([n.format_node() for n in traced.graph.nodes])
1192*da0073e9SAndroid Build Coastguard Worker
1193*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("x").check("placeholder") \
1194*da0073e9SAndroid Build Coastguard Worker            .check("y").check("placeholder") \
1195*da0073e9SAndroid Build Coastguard Worker            .check("getitem").check("call_function") \
1196*da0073e9SAndroid Build Coastguard Worker            .check("param").check("get_attr") \
1197*da0073e9SAndroid Build Coastguard Worker            .check("add").check("call_function") \
1198*da0073e9SAndroid Build Coastguard Worker            .check("linear").check("call_module") \
1199*da0073e9SAndroid Build Coastguard Worker            .check("clamp").check("call_method") \
1200*da0073e9SAndroid Build Coastguard Worker            .run(all_formatted)
1201*da0073e9SAndroid Build Coastguard Worker
1202*da0073e9SAndroid Build Coastguard Worker    def test_script_tensor_constant(self):
1203*da0073e9SAndroid Build Coastguard Worker        # TorchScript seems to ignore attributes that start with `__`.
1204*da0073e9SAndroid Build Coastguard Worker        # We used to call anonymous Tensor values `__tensor_constant*`, but
1205*da0073e9SAndroid Build Coastguard Worker        # they were getting ignored by script. Now they're called
1206*da0073e9SAndroid Build Coastguard Worker        # `_tensor_constant*`
1207*da0073e9SAndroid Build Coastguard Worker        class IHaveATensorConstant(torch.nn.Module):
1208*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1209*da0073e9SAndroid Build Coastguard Worker                return x + torch.rand(3, 4)
1210*da0073e9SAndroid Build Coastguard Worker
1211*da0073e9SAndroid Build Coastguard Worker        traced = torch.fx.symbolic_trace(IHaveATensorConstant())
1212*da0073e9SAndroid Build Coastguard Worker        torch.jit.script(traced)
1213*da0073e9SAndroid Build Coastguard Worker
1214*da0073e9SAndroid Build Coastguard Worker    def test_autowrap_functions(self):
1215*da0073e9SAndroid Build Coastguard Worker        class AutowrapFnTest(torch.nn.Module):
1216*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1217*da0073e9SAndroid Build Coastguard Worker                return fx_int(x.shape[0] / 2)
1218*da0073e9SAndroid Build Coastguard Worker
1219*da0073e9SAndroid Build Coastguard Worker        class AutowrapFnTest2(torch.nn.Module):
1220*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1221*da0073e9SAndroid Build Coastguard Worker                return fx_int(x.shape[0] / 2) + fx_int_x2(x.shape[0] / 2)
1222*da0073e9SAndroid Build Coastguard Worker
1223*da0073e9SAndroid Build Coastguard Worker        # Check function(s) are wrapped
1224*da0073e9SAndroid Build Coastguard Worker        # `int` would normally throw a TypeError as argument can't be `Proxy`
1225*da0073e9SAndroid Build Coastguard Worker        tracer = Tracer(autowrap_functions=(fx_int,))
1226*da0073e9SAndroid Build Coastguard Worker        graph = tracer.trace(AutowrapFnTest())
1227*da0073e9SAndroid Build Coastguard Worker        traced = GraphModule(tracer.root, graph, 'test')
1228*da0073e9SAndroid Build Coastguard Worker        tracer_2 = Tracer(autowrap_functions=(fx_int, fx_int_x2))
1229*da0073e9SAndroid Build Coastguard Worker        tracer_2.trace(AutowrapFnTest2())
1230*da0073e9SAndroid Build Coastguard Worker
1231*da0073e9SAndroid Build Coastguard Worker        # Test scriptability
1232*da0073e9SAndroid Build Coastguard Worker        traced_scripted = torch.jit.script(traced)
1233*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_scripted(torch.rand(4)), 2)
1234*da0073e9SAndroid Build Coastguard Worker
1235*da0073e9SAndroid Build Coastguard Worker    def test_tuple_no_subscript(self):
1236*da0073e9SAndroid Build Coastguard Worker        def foo(x : Tuple):
1237*da0073e9SAndroid Build Coastguard Worker            return x[0]
1238*da0073e9SAndroid Build Coastguard Worker
1239*da0073e9SAndroid Build Coastguard Worker        traced = torch.fx.symbolic_trace(foo)
1240*da0073e9SAndroid Build Coastguard Worker        x = (torch.randn(5, 3),)
1241*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(traced(x), x[0])
1242*da0073e9SAndroid Build Coastguard Worker
1243*da0073e9SAndroid Build Coastguard Worker        bio = io.BytesIO()
1244*da0073e9SAndroid Build Coastguard Worker
1245*da0073e9SAndroid Build Coastguard Worker        torch.save(traced, bio)
1246*da0073e9SAndroid Build Coastguard Worker
1247*da0073e9SAndroid Build Coastguard Worker        bio.seek(0)
1248*da0073e9SAndroid Build Coastguard Worker
1249*da0073e9SAndroid Build Coastguard Worker        # weights_only=False as this loads a GraphModule
1250*da0073e9SAndroid Build Coastguard Worker        # GLOBAL torch.fx.graph_module.reduce_graph_module was not an allowed global by default
1251*da0073e9SAndroid Build Coastguard Worker        loaded = torch.load(bio, weights_only=False)
1252*da0073e9SAndroid Build Coastguard Worker
1253*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(loaded(x), x[0])
1254*da0073e9SAndroid Build Coastguard Worker
1255*da0073e9SAndroid Build Coastguard Worker    def test_torch_fx_len(self):
1256*da0073e9SAndroid Build Coastguard Worker        class FXLenTest(torch.nn.Module):
1257*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1258*da0073e9SAndroid Build Coastguard Worker                return len(x)
1259*da0073e9SAndroid Build Coastguard Worker
1260*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(FXLenTest())
1261*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced(torch.rand(3, 4)), 3)
1262*da0073e9SAndroid Build Coastguard Worker
1263*da0073e9SAndroid Build Coastguard Worker        # Test scriptability
1264*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.script(FXLenTest())
1265*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted(torch.rand(3)), 3)
1266*da0073e9SAndroid Build Coastguard Worker
1267*da0073e9SAndroid Build Coastguard Worker        traced_scripted = torch.jit.script(traced)
1268*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_scripted(torch.rand(3)), 3)
1269*da0073e9SAndroid Build Coastguard Worker
1270*da0073e9SAndroid Build Coastguard Worker        # Test non-proxy len
1271*da0073e9SAndroid Build Coastguard Worker        class FXLenTest2(torch.nn.Module):
1272*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1273*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1274*da0073e9SAndroid Build Coastguard Worker                self.l = [3, 4, 5]
1275*da0073e9SAndroid Build Coastguard Worker
1276*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1277*da0073e9SAndroid Build Coastguard Worker                return x + len(self.l)
1278*da0073e9SAndroid Build Coastguard Worker
1279*da0073e9SAndroid Build Coastguard Worker        traced2 = symbolic_trace(FXLenTest2())
1280*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(3, 4)
1281*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced2(inp), inp + 3.0)
1282*da0073e9SAndroid Build Coastguard Worker        self.assertIs(len, builtins.len)
1283*da0073e9SAndroid Build Coastguard Worker
1284*da0073e9SAndroid Build Coastguard Worker    def test_torch_fx_getattr(self):
1285*da0073e9SAndroid Build Coastguard Worker        class FXGetattrTest(torch.nn.Module):
1286*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1287*da0073e9SAndroid Build Coastguard Worker                return getattr(x, 'nonexistent_attr', torch.Tensor([2, 3]))
1288*da0073e9SAndroid Build Coastguard Worker
1289*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(FXGetattrTest())
1290*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced(torch.rand(3, 4)), torch.Tensor([2, 3]))
1291*da0073e9SAndroid Build Coastguard Worker
1292*da0073e9SAndroid Build Coastguard Worker    def test_sqrt(self):
1293*da0073e9SAndroid Build Coastguard Worker        class Sqrt1(torch.nn.Module):
1294*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1295*da0073e9SAndroid Build Coastguard Worker                return sqrt(x.size(0))
1296*da0073e9SAndroid Build Coastguard Worker
1297*da0073e9SAndroid Build Coastguard Worker        class Sqrt2(torch.nn.Module):
1298*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1299*da0073e9SAndroid Build Coastguard Worker                return math.sqrt(x.size(0))
1300*da0073e9SAndroid Build Coastguard Worker
1301*da0073e9SAndroid Build Coastguard Worker        class Sqrt3(torch.nn.Module):
1302*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1303*da0073e9SAndroid Build Coastguard Worker                return x + math.sqrt(2) + sqrt(2)
1304*da0073e9SAndroid Build Coastguard Worker
1305*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(Sqrt1(), [torch.zeros(8)])
1306*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(Sqrt2(), [torch.zeros(8)])
1307*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(Sqrt3(), [torch.zeros(8)])
1308*da0073e9SAndroid Build Coastguard Worker        self.assertIs(sqrt, _sqrt)
1309*da0073e9SAndroid Build Coastguard Worker        self.assertIs(math.sqrt, _sqrt)
1310*da0073e9SAndroid Build Coastguard Worker
1311*da0073e9SAndroid Build Coastguard Worker    def test_torch_custom_ops(self):
1312*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
1313*da0073e9SAndroid Build Coastguard Worker            def forward(self, a):
1314*da0073e9SAndroid Build Coastguard Worker                b = torch.ops.aten.sigmoid(a)
1315*da0073e9SAndroid Build Coastguard Worker                c = torch.ops.aten.cat([a, b])
1316*da0073e9SAndroid Build Coastguard Worker                return torch.ops.aten.cat((c, c))
1317*da0073e9SAndroid Build Coastguard Worker        m = M()
1318*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3)
1319*da0073e9SAndroid Build Coastguard Worker        ref_out = m(input)
1320*da0073e9SAndroid Build Coastguard Worker        gm = symbolic_trace(m)
1321*da0073e9SAndroid Build Coastguard Worker        gm.graph.lint()
1322*da0073e9SAndroid Build Coastguard Worker        out = gm(input)
1323*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, ref_out)
1324*da0073e9SAndroid Build Coastguard Worker
1325*da0073e9SAndroid Build Coastguard Worker    def test_torch_op_overloads(self):
1326*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
1327*da0073e9SAndroid Build Coastguard Worker            def forward(self, a):
1328*da0073e9SAndroid Build Coastguard Worker                b = torch.ops.aten.add.Tensor(a, a)
1329*da0073e9SAndroid Build Coastguard Worker                return b
1330*da0073e9SAndroid Build Coastguard Worker        m = M()
1331*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3)
1332*da0073e9SAndroid Build Coastguard Worker        ref_out = m(input)
1333*da0073e9SAndroid Build Coastguard Worker        gm = symbolic_trace(m)
1334*da0073e9SAndroid Build Coastguard Worker        gm.graph.lint()
1335*da0073e9SAndroid Build Coastguard Worker        out = gm(input)
1336*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, ref_out)
1337*da0073e9SAndroid Build Coastguard Worker
1338*da0073e9SAndroid Build Coastguard Worker        for node in gm.graph.nodes:
1339*da0073e9SAndroid Build Coastguard Worker            if node.op == 'call_function':
1340*da0073e9SAndroid Build Coastguard Worker                assert isinstance(node.target, torch._ops.OpOverload)
1341*da0073e9SAndroid Build Coastguard Worker                assert node.target.__name__ == 'add.Tensor'
1342*da0073e9SAndroid Build Coastguard Worker
1343*da0073e9SAndroid Build Coastguard Worker    def test_pickle_torch_custom_ops(self):
1344*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
1345*da0073e9SAndroid Build Coastguard Worker            def forward(self, a):
1346*da0073e9SAndroid Build Coastguard Worker                b = torch.ops.aten.sigmoid(a)
1347*da0073e9SAndroid Build Coastguard Worker                c = torch.ops.aten.cat([a, b])
1348*da0073e9SAndroid Build Coastguard Worker                return torch.ops.aten.cat((c, c))
1349*da0073e9SAndroid Build Coastguard Worker        m = M()
1350*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3)
1351*da0073e9SAndroid Build Coastguard Worker        ref_out = m(input)
1352*da0073e9SAndroid Build Coastguard Worker        gm = symbolic_trace(m)
1353*da0073e9SAndroid Build Coastguard Worker        gm.graph.lint()
1354*da0073e9SAndroid Build Coastguard Worker        pickled = pickle.dumps(gm)
1355*da0073e9SAndroid Build Coastguard Worker        loaded = pickle.loads(pickled)
1356*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(loaded(input), gm(input))
1357*da0073e9SAndroid Build Coastguard Worker
1358*da0073e9SAndroid Build Coastguard Worker    def test_pretty_print(self):
1359*da0073e9SAndroid Build Coastguard Worker        st = SimpleTest()
1360*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(st)
1361*da0073e9SAndroid Build Coastguard Worker        traced.graph.lint()
1362*da0073e9SAndroid Build Coastguard Worker        printed = str(traced)
1363*da0073e9SAndroid Build Coastguard Worker        assert 'SimpleTest()' in printed
1364*da0073e9SAndroid Build Coastguard Worker        assert 'torch.relu' in printed
1365*da0073e9SAndroid Build Coastguard Worker
1366*da0073e9SAndroid Build Coastguard Worker    def test_pretty_print_graph(self):
1367*da0073e9SAndroid Build Coastguard Worker        class KwargPrintTest(torch.nn.Module):
1368*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1369*da0073e9SAndroid Build Coastguard Worker                return torch.squeeze(x + 3.0, dim=2)
1370*da0073e9SAndroid Build Coastguard Worker        st = KwargPrintTest()
1371*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(st)
1372*da0073e9SAndroid Build Coastguard Worker        traced.graph.lint()
1373*da0073e9SAndroid Build Coastguard Worker        stringed = str(traced.graph)
1374*da0073e9SAndroid Build Coastguard Worker        for s in ['args', 'kwargs', 'num_users']:
1375*da0073e9SAndroid Build Coastguard Worker            assert s in stringed
1376*da0073e9SAndroid Build Coastguard Worker
1377*da0073e9SAndroid Build Coastguard Worker    def test_custom_proxy_type(self):
1378*da0073e9SAndroid Build Coastguard Worker        class TensorPair:
1379*da0073e9SAndroid Build Coastguard Worker            def __init__(self, left, right):
1380*da0073e9SAndroid Build Coastguard Worker                self.left, self.right = left, right
1381*da0073e9SAndroid Build Coastguard Worker
1382*da0073e9SAndroid Build Coastguard Worker            def add(self, other):
1383*da0073e9SAndroid Build Coastguard Worker                l = self.left + other.left
1384*da0073e9SAndroid Build Coastguard Worker                r = self.right + other.right
1385*da0073e9SAndroid Build Coastguard Worker                return TensorPair(l, r)
1386*da0073e9SAndroid Build Coastguard Worker
1387*da0073e9SAndroid Build Coastguard Worker            def mul(self, other):
1388*da0073e9SAndroid Build Coastguard Worker                l = self.left * other.left
1389*da0073e9SAndroid Build Coastguard Worker                r = self.right * other.right
1390*da0073e9SAndroid Build Coastguard Worker                return TensorPair(l, r)
1391*da0073e9SAndroid Build Coastguard Worker
1392*da0073e9SAndroid Build Coastguard Worker        def use_tensor_pair(x : TensorPair, y : TensorPair):
1393*da0073e9SAndroid Build Coastguard Worker            s = x.add(y)
1394*da0073e9SAndroid Build Coastguard Worker            return s.mul(x)
1395*da0073e9SAndroid Build Coastguard Worker
1396*da0073e9SAndroid Build Coastguard Worker        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
1397*da0073e9SAndroid Build Coastguard Worker        y = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
1398*da0073e9SAndroid Build Coastguard Worker
1399*da0073e9SAndroid Build Coastguard Worker        ref_out = use_tensor_pair(x, y)
1400*da0073e9SAndroid Build Coastguard Worker
1401*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(use_tensor_pair)
1402*da0073e9SAndroid Build Coastguard Worker
1403*da0073e9SAndroid Build Coastguard Worker        traced_out = traced(x, y)
1404*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_out.left, ref_out.left)
1405*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_out.right, ref_out.right)
1406*da0073e9SAndroid Build Coastguard Worker
1407*da0073e9SAndroid Build Coastguard Worker    def test_custom_proxy_type_literal(self):
1408*da0073e9SAndroid Build Coastguard Worker        class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
1409*da0073e9SAndroid Build Coastguard Worker            def __init__(self, left, right):
1410*da0073e9SAndroid Build Coastguard Worker                self.left, self.right = left, right
1411*da0073e9SAndroid Build Coastguard Worker
1412*da0073e9SAndroid Build Coastguard Worker            def add(self, other):
1413*da0073e9SAndroid Build Coastguard Worker                l = self.left + other.left
1414*da0073e9SAndroid Build Coastguard Worker                r = self.right + other.right
1415*da0073e9SAndroid Build Coastguard Worker                return TensorPair(l, r)
1416*da0073e9SAndroid Build Coastguard Worker
1417*da0073e9SAndroid Build Coastguard Worker            def mul(self, other):
1418*da0073e9SAndroid Build Coastguard Worker                l = self.left * other.left
1419*da0073e9SAndroid Build Coastguard Worker                r = self.right * other.right
1420*da0073e9SAndroid Build Coastguard Worker                return TensorPair(l, r)
1421*da0073e9SAndroid Build Coastguard Worker
1422*da0073e9SAndroid Build Coastguard Worker        def use_tensor_pair_literal(x : TensorPair):
1423*da0073e9SAndroid Build Coastguard Worker            s = x.add(TensorPair(torch.zeros(5, 3), torch.zeros(5, 3)))
1424*da0073e9SAndroid Build Coastguard Worker            return s.mul(x)
1425*da0073e9SAndroid Build Coastguard Worker
1426*da0073e9SAndroid Build Coastguard Worker        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
1427*da0073e9SAndroid Build Coastguard Worker
1428*da0073e9SAndroid Build Coastguard Worker        ref_out = use_tensor_pair_literal(x)
1429*da0073e9SAndroid Build Coastguard Worker
1430*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(use_tensor_pair_literal)
1431*da0073e9SAndroid Build Coastguard Worker
1432*da0073e9SAndroid Build Coastguard Worker        traced_out = traced(x)
1433*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_out.left, ref_out.left)
1434*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_out.right, ref_out.right)
1435*da0073e9SAndroid Build Coastguard Worker
1436*da0073e9SAndroid Build Coastguard Worker    def test_custom_proxy_dynamic_value(self):
1437*da0073e9SAndroid Build Coastguard Worker        class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
1438*da0073e9SAndroid Build Coastguard Worker            def __init__(self, left, right):
1439*da0073e9SAndroid Build Coastguard Worker                self.left, self.right = left, right
1440*da0073e9SAndroid Build Coastguard Worker
1441*da0073e9SAndroid Build Coastguard Worker            def add(self, other):
1442*da0073e9SAndroid Build Coastguard Worker                l = self.left + other.left
1443*da0073e9SAndroid Build Coastguard Worker                r = self.right + other.right
1444*da0073e9SAndroid Build Coastguard Worker                return TensorPair(l, r)
1445*da0073e9SAndroid Build Coastguard Worker
1446*da0073e9SAndroid Build Coastguard Worker            def mul(self, other):
1447*da0073e9SAndroid Build Coastguard Worker                l = self.left * other.left
1448*da0073e9SAndroid Build Coastguard Worker                r = self.right * other.right
1449*da0073e9SAndroid Build Coastguard Worker                return TensorPair(l, r)
1450*da0073e9SAndroid Build Coastguard Worker
1451*da0073e9SAndroid Build Coastguard Worker        def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor):
1452*da0073e9SAndroid Build Coastguard Worker            s = x.add(TensorPair(y, y))
1453*da0073e9SAndroid Build Coastguard Worker            return s.mul(x)
1454*da0073e9SAndroid Build Coastguard Worker
1455*da0073e9SAndroid Build Coastguard Worker        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
1456*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(5, 3)
1457*da0073e9SAndroid Build Coastguard Worker        ref_out = use_tensor_pair_ctor(x, y)
1458*da0073e9SAndroid Build Coastguard Worker
1459*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(use_tensor_pair_ctor)
1460*da0073e9SAndroid Build Coastguard Worker
1461*da0073e9SAndroid Build Coastguard Worker        traced_out = traced(x, y)
1462*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_out.left, ref_out.left)
1463*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_out.right, ref_out.right)
1464*da0073e9SAndroid Build Coastguard Worker
1465*da0073e9SAndroid Build Coastguard Worker    def test_custom_proxy_input_dependent_control_flow(self):
1466*da0073e9SAndroid Build Coastguard Worker        class ZeroTensor(metaclass=torch.fx.ProxyableClassMeta):
1467*da0073e9SAndroid Build Coastguard Worker            def __init__(self, inp):
1468*da0073e9SAndroid Build Coastguard Worker                if inp.sum() == 0:
1469*da0073e9SAndroid Build Coastguard Worker                    self.is_zero = True
1470*da0073e9SAndroid Build Coastguard Worker                    self.tensor = torch.tensor([])
1471*da0073e9SAndroid Build Coastguard Worker                else:
1472*da0073e9SAndroid Build Coastguard Worker                    self.is_zero = False
1473*da0073e9SAndroid Build Coastguard Worker                    self.tensor = inp
1474*da0073e9SAndroid Build Coastguard Worker
1475*da0073e9SAndroid Build Coastguard Worker            def add(self, other):
1476*da0073e9SAndroid Build Coastguard Worker                if self.is_zero:
1477*da0073e9SAndroid Build Coastguard Worker                    return ZeroTensor(other.tensor)
1478*da0073e9SAndroid Build Coastguard Worker                elif other.is_zero:
1479*da0073e9SAndroid Build Coastguard Worker                    return self
1480*da0073e9SAndroid Build Coastguard Worker
1481*da0073e9SAndroid Build Coastguard Worker        def use_zero_tensor(x : torch.Tensor, y : torch.Tensor):
1482*da0073e9SAndroid Build Coastguard Worker            return ZeroTensor(x + y)
1483*da0073e9SAndroid Build Coastguard Worker
1484*da0073e9SAndroid Build Coastguard Worker        x, y = torch.randn(5, 3), torch.randn(5, 3)
1485*da0073e9SAndroid Build Coastguard Worker
1486*da0073e9SAndroid Build Coastguard Worker        ref_out = use_zero_tensor(x, y)
1487*da0073e9SAndroid Build Coastguard Worker
1488*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(use_zero_tensor)
1489*da0073e9SAndroid Build Coastguard Worker
1490*da0073e9SAndroid Build Coastguard Worker        traced_out = traced(x, y)
1491*da0073e9SAndroid Build Coastguard Worker
1492*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_out.is_zero, ref_out.is_zero)
1493*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_out.tensor, ref_out.tensor)
1494*da0073e9SAndroid Build Coastguard Worker
1495*da0073e9SAndroid Build Coastguard Worker    def test_graph_fns(self):
1496*da0073e9SAndroid Build Coastguard Worker        g = Graph()
1497*da0073e9SAndroid Build Coastguard Worker        a = g.placeholder('a')
1498*da0073e9SAndroid Build Coastguard Worker        b = g.call_module('linear', (a,))
1499*da0073e9SAndroid Build Coastguard Worker        c = g.get_attr('bias')
1500*da0073e9SAndroid Build Coastguard Worker        d = g.call_method('add', (b, c))
1501*da0073e9SAndroid Build Coastguard Worker        e = g.call_function(torch.sin, (d,))
1502*da0073e9SAndroid Build Coastguard Worker        g.output(e)
1503*da0073e9SAndroid Build Coastguard Worker        mod = torch.nn.Module()
1504*da0073e9SAndroid Build Coastguard Worker        mod.linear = torch.nn.Linear(3, 4)
1505*da0073e9SAndroid Build Coastguard Worker        mod.bias = torch.rand(4)
1506*da0073e9SAndroid Build Coastguard Worker        gm = GraphModule(mod, g)
1507*da0073e9SAndroid Build Coastguard Worker        gm.graph.lint()
1508*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(3)
1509*da0073e9SAndroid Build Coastguard Worker        r = gm(input)
1510*da0073e9SAndroid Build Coastguard Worker        ref = torch.sin(mod.linear(input) + mod.bias)
1511*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r, ref)
1512*da0073e9SAndroid Build Coastguard Worker
1513*da0073e9SAndroid Build Coastguard Worker    def test_remove_uses(self):
1514*da0073e9SAndroid Build Coastguard Worker        g : torch.fx.Graph = Graph()
1515*da0073e9SAndroid Build Coastguard Worker        x : torch.fx.Node = g.placeholder('x')
1516*da0073e9SAndroid Build Coastguard Worker        relu : torch.fx.Node = g.call_function(torch.relu, (x,))
1517*da0073e9SAndroid Build Coastguard Worker        neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
1518*da0073e9SAndroid Build Coastguard Worker        g.output(neg)
1519*da0073e9SAndroid Build Coastguard Worker
1520*da0073e9SAndroid Build Coastguard Worker        neg.replace_all_uses_with(relu)
1521*da0073e9SAndroid Build Coastguard Worker        g.erase_node(neg)
1522*da0073e9SAndroid Build Coastguard Worker
1523*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(neg not in relu.users)
1524*da0073e9SAndroid Build Coastguard Worker
1525*da0073e9SAndroid Build Coastguard Worker    def test_remove_uses_with_custom_filter(self):
1526*da0073e9SAndroid Build Coastguard Worker        g : torch.fx.Graph = Graph()
1527*da0073e9SAndroid Build Coastguard Worker        x : torch.fx.Node = g.placeholder('x')
1528*da0073e9SAndroid Build Coastguard Worker        relu : torch.fx.Node = g.call_function(torch.relu, (x,))
1529*da0073e9SAndroid Build Coastguard Worker        neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
1530*da0073e9SAndroid Build Coastguard Worker        g.output(neg)
1531*da0073e9SAndroid Build Coastguard Worker
1532*da0073e9SAndroid Build Coastguard Worker        neg.replace_all_uses_with(relu, lambda x: x != neg)
1533*da0073e9SAndroid Build Coastguard Worker
1534*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(neg in relu.users)
1535*da0073e9SAndroid Build Coastguard Worker
1536*da0073e9SAndroid Build Coastguard Worker    def test_nonetype_annotation(self):
1537*da0073e9SAndroid Build Coastguard Worker        eb = torch.nn.EmbeddingBag(3, 4)
1538*da0073e9SAndroid Build Coastguard Worker        symbolic_trace(eb)
1539*da0073e9SAndroid Build Coastguard Worker
1540*da0073e9SAndroid Build Coastguard Worker    def test_pickle_nonetype_annotation(self):
1541*da0073e9SAndroid Build Coastguard Worker        eb = torch.nn.EmbeddingBag(10, 3, mode='sum')
1542*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(eb)
1543*da0073e9SAndroid Build Coastguard Worker        pickled = pickle.dumps(traced)
1544*da0073e9SAndroid Build Coastguard Worker        loaded = pickle.loads(pickled)
1545*da0073e9SAndroid Build Coastguard Worker        loaded.graph.lint()
1546*da0073e9SAndroid Build Coastguard Worker        input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
1547*da0073e9SAndroid Build Coastguard Worker        offsets = torch.LongTensor([0, 4])
1548*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(loaded(input, offsets), traced(input, offsets))
1549*da0073e9SAndroid Build Coastguard Worker
1550*da0073e9SAndroid Build Coastguard Worker    def test_return_tuple(self):
1551*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
1552*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
1553*da0073e9SAndroid Build Coastguard Worker                return (x, x + x)
1554*da0073e9SAndroid Build Coastguard Worker
1555*da0073e9SAndroid Build Coastguard Worker        original = M()
1556*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(original)
1557*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced(torch.ones(1)), original.forward(torch.ones(1)))
1558*da0073e9SAndroid Build Coastguard Worker
1559*da0073e9SAndroid Build Coastguard Worker    def test_construct_root_dict(self):
1560*da0073e9SAndroid Build Coastguard Worker        graph : torch.fx.Graph = torch.fx.Graph()
1561*da0073e9SAndroid Build Coastguard Worker        a : torch.fx.Node = graph.create_node('placeholder', 'x')
1562*da0073e9SAndroid Build Coastguard Worker        b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
1563*da0073e9SAndroid Build Coastguard Worker        c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
1564*da0073e9SAndroid Build Coastguard Worker        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
1565*da0073e9SAndroid Build Coastguard Worker        graph.output(d)
1566*da0073e9SAndroid Build Coastguard Worker
1567*da0073e9SAndroid Build Coastguard Worker        linear_mod : torch.nn.Module = torch.nn.Linear(3, 4)
1568*da0073e9SAndroid Build Coastguard Worker        add_param : torch.Tensor = torch.rand(3, 4)
1569*da0073e9SAndroid Build Coastguard Worker        gm : torch.fx.GraphModule = torch.fx.GraphModule(
1570*da0073e9SAndroid Build Coastguard Worker            {'foo.bar.baz': linear_mod, 'zip.zap.zam' : add_param}, graph)
1571*da0073e9SAndroid Build Coastguard Worker        gm.graph.lint()
1572*da0073e9SAndroid Build Coastguard Worker
1573*da0073e9SAndroid Build Coastguard Worker        assert 'self.foo.bar.baz' in gm.code
1574*da0073e9SAndroid Build Coastguard Worker
1575*da0073e9SAndroid Build Coastguard Worker        x : torch.Tensor = torch.rand(3, 3)
1576*da0073e9SAndroid Build Coastguard Worker        out : torch.Tensor = gm(x)
1577*da0073e9SAndroid Build Coastguard Worker        ref_out : torch.Tensor = linear_mod(x) + add_param
1578*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, ref_out)
1579*da0073e9SAndroid Build Coastguard Worker
1580*da0073e9SAndroid Build Coastguard Worker    def test_symbolic_trace_assert(self):
1581*da0073e9SAndroid Build Coastguard Worker
1582*da0073e9SAndroid Build Coastguard Worker        class AssertsTensorShape(torch.nn.Module):
1583*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1584*da0073e9SAndroid Build Coastguard Worker                torch._assert(x.shape[1] > 4, "assert_foobar")
1585*da0073e9SAndroid Build Coastguard Worker                return x
1586*da0073e9SAndroid Build Coastguard Worker
1587*da0073e9SAndroid Build Coastguard Worker        m = AssertsTensorShape()
1588*da0073e9SAndroid Build Coastguard Worker        # verify traceability
1589*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(m)
1590*da0073e9SAndroid Build Coastguard Worker        # verify assertion on traced model works correctly at runtime
1591*da0073e9SAndroid Build Coastguard Worker        traced(torch.rand(4, 5))
1592*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(AssertionError, "assert_foobar"):
1593*da0073e9SAndroid Build Coastguard Worker            traced(torch.rand(4, 3))
1594*da0073e9SAndroid Build Coastguard Worker        # verify the symbolically traced module is scriptable
1595*da0073e9SAndroid Build Coastguard Worker        ms = torch.jit.script(m)
1596*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(torch.jit.Error, "assert_foobar"):
1597*da0073e9SAndroid Build Coastguard Worker            ms(torch.rand(4, 3))
1598*da0073e9SAndroid Build Coastguard Worker
1599*da0073e9SAndroid Build Coastguard Worker    def test_fx_create_arg(self):
1600*da0073e9SAndroid Build Coastguard Worker        class CustomArgObject:
1601*da0073e9SAndroid Build Coastguard Worker            def __init__(self, x, y):
1602*da0073e9SAndroid Build Coastguard Worker                self.x = x
1603*da0073e9SAndroid Build Coastguard Worker                self.y = y
1604*da0073e9SAndroid Build Coastguard Worker
1605*da0073e9SAndroid Build Coastguard Worker            def __fx_create_arg__(self, tracer: torch.fx.Tracer):
1606*da0073e9SAndroid Build Coastguard Worker                return tracer.create_node(
1607*da0073e9SAndroid Build Coastguard Worker                    "call_function",
1608*da0073e9SAndroid Build Coastguard Worker                    CustomArgObject,
1609*da0073e9SAndroid Build Coastguard Worker                    args=(
1610*da0073e9SAndroid Build Coastguard Worker                        tracer.create_arg(self.x),
1611*da0073e9SAndroid Build Coastguard Worker                        tracer.create_arg(self.y),
1612*da0073e9SAndroid Build Coastguard Worker                    ),
1613*da0073e9SAndroid Build Coastguard Worker                    kwargs={},
1614*da0073e9SAndroid Build Coastguard Worker                )
1615*da0073e9SAndroid Build Coastguard Worker
1616*da0073e9SAndroid Build Coastguard Worker        class HasCustomArgObjectWhenLeaf(torch.nn.Module):
1617*da0073e9SAndroid Build Coastguard Worker            def forward(self, o: CustomArgObject):
1618*da0073e9SAndroid Build Coastguard Worker                # Not normally traceable; good reason to make
1619*da0073e9SAndroid Build Coastguard Worker                # this module a leaf.
1620*da0073e9SAndroid Build Coastguard Worker                for x in o.x:
1621*da0073e9SAndroid Build Coastguard Worker                    o.y += x
1622*da0073e9SAndroid Build Coastguard Worker                return o.y
1623*da0073e9SAndroid Build Coastguard Worker
1624*da0073e9SAndroid Build Coastguard Worker        class Root(torch.nn.Module):
1625*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1626*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1627*da0073e9SAndroid Build Coastguard Worker                self.inner = HasCustomArgObjectWhenLeaf()
1628*da0073e9SAndroid Build Coastguard Worker
1629*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y):
1630*da0073e9SAndroid Build Coastguard Worker                o = CustomArgObject(x, y)
1631*da0073e9SAndroid Build Coastguard Worker                return self.inner(o)
1632*da0073e9SAndroid Build Coastguard Worker
1633*da0073e9SAndroid Build Coastguard Worker        class CreateArgTracer(torch.fx.Tracer):
1634*da0073e9SAndroid Build Coastguard Worker            def is_leaf_module(self, m, module_qualified_name):
1635*da0073e9SAndroid Build Coastguard Worker                return type(m) is HasCustomArgObjectWhenLeaf
1636*da0073e9SAndroid Build Coastguard Worker
1637*da0073e9SAndroid Build Coastguard Worker        m = Root()
1638*da0073e9SAndroid Build Coastguard Worker        graph = CreateArgTracer().trace(m)
1639*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.GraphModule(m, graph)
1640*da0073e9SAndroid Build Coastguard Worker        assert "CustomArgObject(" in gm.code
1641*da0073e9SAndroid Build Coastguard Worker
1642*da0073e9SAndroid Build Coastguard Worker    def test_trace_fn_constant(self):
1643*da0073e9SAndroid Build Coastguard Worker        some_constant = torch.rand(3, 4)
1644*da0073e9SAndroid Build Coastguard Worker
1645*da0073e9SAndroid Build Coastguard Worker        def add_const(x):
1646*da0073e9SAndroid Build Coastguard Worker            return some_constant + x
1647*da0073e9SAndroid Build Coastguard Worker
1648*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(add_const)
1649*da0073e9SAndroid Build Coastguard Worker
1650*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(3, 4)
1651*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced(input), add_const(input))
1652*da0073e9SAndroid Build Coastguard Worker
1653*da0073e9SAndroid Build Coastguard Worker    def test_copy_no_remap(self):
1654*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(SimpleTest())
1655*da0073e9SAndroid Build Coastguard Worker        g = traced.graph
1656*da0073e9SAndroid Build Coastguard Worker        copied = torch.fx.Graph()
1657*da0073e9SAndroid Build Coastguard Worker        for node in g.nodes:
1658*da0073e9SAndroid Build Coastguard Worker            copied.node_copy(node)
1659*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'does not belong to this Graph'):
1660*da0073e9SAndroid Build Coastguard Worker            copied.lint()
1661*da0073e9SAndroid Build Coastguard Worker
1662*da0073e9SAndroid Build Coastguard Worker    def test_wrong_topo(self):
1663*da0073e9SAndroid Build Coastguard Worker        graph : torch.fx.Graph = torch.fx.Graph()
1664*da0073e9SAndroid Build Coastguard Worker        a : torch.fx.Node = graph.create_node('placeholder', 'x')
1665*da0073e9SAndroid Build Coastguard Worker        b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
1666*da0073e9SAndroid Build Coastguard Worker        c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
1667*da0073e9SAndroid Build Coastguard Worker        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
1668*da0073e9SAndroid Build Coastguard Worker        graph.output(d)
1669*da0073e9SAndroid Build Coastguard Worker        nodes = list(graph.nodes)
1670*da0073e9SAndroid Build Coastguard Worker        nodes[3].append(nodes[2])
1671*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'was used before it has been defined'):
1672*da0073e9SAndroid Build Coastguard Worker            graph.lint()
1673*da0073e9SAndroid Build Coastguard Worker
1674*da0073e9SAndroid Build Coastguard Worker    def test_wrong_target_type(self):
1675*da0073e9SAndroid Build Coastguard Worker        graph : torch.fx.Graph = torch.fx.Graph()
1676*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
1677*da0073e9SAndroid Build Coastguard Worker            n = torch.fx.Node(graph=graph, name='foo', op='call_function', target='foo',
1678*da0073e9SAndroid Build Coastguard Worker                              args=(), kwargs={})
1679*da0073e9SAndroid Build Coastguard Worker
1680*da0073e9SAndroid Build Coastguard Worker    def test_example_shape_prop(self):
1681*da0073e9SAndroid Build Coastguard Worker        class TestCase(torch.nn.Module):
1682*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1683*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1684*da0073e9SAndroid Build Coastguard Worker                self.attr = torch.randn(3, 4)
1685*da0073e9SAndroid Build Coastguard Worker                self.submod = torch.nn.Linear(4, 4)
1686*da0073e9SAndroid Build Coastguard Worker
1687*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1688*da0073e9SAndroid Build Coastguard Worker                return torch.neg(self.submod(x.relu() + self.attr))
1689*da0073e9SAndroid Build Coastguard Worker        tc = TestCase()
1690*da0073e9SAndroid Build Coastguard Worker        tc_traced = symbolic_trace(tc)
1691*da0073e9SAndroid Build Coastguard Worker        ref_out = tc_traced(torch.rand(3, 4))
1692*da0073e9SAndroid Build Coastguard Worker        shape_prop.ShapeProp(tc_traced).propagate(torch.rand(3, 4))
1693*da0073e9SAndroid Build Coastguard Worker
1694*da0073e9SAndroid Build Coastguard Worker        # Make sure we're testing all opcodes
1695*da0073e9SAndroid Build Coastguard Worker        opcodes = set()
1696*da0073e9SAndroid Build Coastguard Worker        output_shape : Optional[torch.Shape] = None
1697*da0073e9SAndroid Build Coastguard Worker        output_stride : Optional[Tuple[int]] = None
1698*da0073e9SAndroid Build Coastguard Worker        for node in tc_traced.graph.nodes:
1699*da0073e9SAndroid Build Coastguard Worker            opcodes.add(node.op)
1700*da0073e9SAndroid Build Coastguard Worker            if node.op == 'output':
1701*da0073e9SAndroid Build Coastguard Worker                output_shape = node.args[0].meta['tensor_meta'].shape
1702*da0073e9SAndroid Build Coastguard Worker                output_stride = node.args[0].meta['tensor_meta'].stride
1703*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opcodes, {'placeholder', 'get_attr', 'call_function', 'call_method',
1704*da0073e9SAndroid Build Coastguard Worker                                   'call_module', 'output'})
1705*da0073e9SAndroid Build Coastguard Worker
1706*da0073e9SAndroid Build Coastguard Worker        # Test shape propagation and make sure results match actual
1707*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output_shape, ref_out.shape)
1708*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output_stride, ref_out.stride())
1709*da0073e9SAndroid Build Coastguard Worker
1710*da0073e9SAndroid Build Coastguard Worker    def test_shape_prop_layout(self):
1711*da0073e9SAndroid Build Coastguard Worker        class ConvTest(torch.nn.Module):
1712*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1713*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1714*da0073e9SAndroid Build Coastguard Worker                self.conv_mod = torch.nn.Conv2d(5, 5, 3)
1715*da0073e9SAndroid Build Coastguard Worker
1716*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1717*da0073e9SAndroid Build Coastguard Worker                return self.conv_mod(x)
1718*da0073e9SAndroid Build Coastguard Worker
1719*da0073e9SAndroid Build Coastguard Worker        # contiguous layout
1720*da0073e9SAndroid Build Coastguard Worker        test_mod = ConvTest()
1721*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(test_mod)
1722*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, 224, 224)
1723*da0073e9SAndroid Build Coastguard Worker        shape_prop.ShapeProp(traced).propagate(x)
1724*da0073e9SAndroid Build Coastguard Worker
1725*da0073e9SAndroid Build Coastguard Worker        assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
1726*da0073e9SAndroid Build Coastguard Worker                   for node in traced.graph.nodes)
1727*da0073e9SAndroid Build Coastguard Worker
1728*da0073e9SAndroid Build Coastguard Worker        x_channels_last = x.contiguous(memory_format=torch.channels_last)
1729*da0073e9SAndroid Build Coastguard Worker        traced.to(memory_format=torch.channels_last)
1730*da0073e9SAndroid Build Coastguard Worker        shape_prop.ShapeProp(traced).propagate(x_channels_last)
1731*da0073e9SAndroid Build Coastguard Worker        for node in traced.graph.nodes:
1732*da0073e9SAndroid Build Coastguard Worker            # NB: the implementation of conv may not preserve the memory format,
1733*da0073e9SAndroid Build Coastguard Worker            # unfortunately. The best we can do is just check that the placeholder
1734*da0073e9SAndroid Build Coastguard Worker            # node is channels-last
1735*da0073e9SAndroid Build Coastguard Worker            if node.op in {'placeholder'}:
1736*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last)
1737*da0073e9SAndroid Build Coastguard Worker
1738*da0073e9SAndroid Build Coastguard Worker    def test_shape_prop_aggregate(self):
1739*da0073e9SAndroid Build Coastguard Worker        class ReturnTwo(torch.nn.Module):
1740*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1741*da0073e9SAndroid Build Coastguard Worker                return (3, torch.sum(x))
1742*da0073e9SAndroid Build Coastguard Worker
1743*da0073e9SAndroid Build Coastguard Worker        class UnderTest(torch.nn.Module):
1744*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1745*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1746*da0073e9SAndroid Build Coastguard Worker                self.rt = ReturnTwo()
1747*da0073e9SAndroid Build Coastguard Worker
1748*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1749*da0073e9SAndroid Build Coastguard Worker                return self.rt(x)
1750*da0073e9SAndroid Build Coastguard Worker
1751*da0073e9SAndroid Build Coastguard Worker        ut = UnderTest()
1752*da0073e9SAndroid Build Coastguard Worker
1753*da0073e9SAndroid Build Coastguard Worker        class RTTracer(torch.fx.Tracer):
1754*da0073e9SAndroid Build Coastguard Worker            def is_leaf_module(self, m, module_qualified_name):
1755*da0073e9SAndroid Build Coastguard Worker                return type(m) is ReturnTwo
1756*da0073e9SAndroid Build Coastguard Worker
1757*da0073e9SAndroid Build Coastguard Worker        graph = RTTracer().trace(ut)
1758*da0073e9SAndroid Build Coastguard Worker        mod = torch.fx.GraphModule(ut, graph)
1759*da0073e9SAndroid Build Coastguard Worker
1760*da0073e9SAndroid Build Coastguard Worker        shape_prop.ShapeProp(mod).propagate(torch.rand(3, 4))
1761*da0073e9SAndroid Build Coastguard Worker
1762*da0073e9SAndroid Build Coastguard Worker        for node in mod.graph.nodes:
1763*da0073e9SAndroid Build Coastguard Worker            if node.op == 'call_module':
1764*da0073e9SAndroid Build Coastguard Worker                assert 'tensor_meta' in node.meta
1765*da0073e9SAndroid Build Coastguard Worker                tensor_meta = node.meta['tensor_meta']
1766*da0073e9SAndroid Build Coastguard Worker                assert tensor_meta[0] == 3
1767*da0073e9SAndroid Build Coastguard Worker                assert tensor_meta[1].shape == torch.Size([])
1768*da0073e9SAndroid Build Coastguard Worker
1769*da0073e9SAndroid Build Coastguard Worker    def test_shape_prop_layout_3d(self):
1770*da0073e9SAndroid Build Coastguard Worker        class ConvTest3d(torch.nn.Module):
1771*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1772*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1773*da0073e9SAndroid Build Coastguard Worker                self.conv_mod = torch.nn.Conv3d(5, 5, 3)
1774*da0073e9SAndroid Build Coastguard Worker
1775*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1776*da0073e9SAndroid Build Coastguard Worker                return self.conv_mod(x)
1777*da0073e9SAndroid Build Coastguard Worker
1778*da0073e9SAndroid Build Coastguard Worker        test_mod_3d = ConvTest3d()
1779*da0073e9SAndroid Build Coastguard Worker        traced_3d = symbolic_trace(test_mod_3d)
1780*da0073e9SAndroid Build Coastguard Worker        x_3d = torch.randn(5, 5, 224, 224, 15)
1781*da0073e9SAndroid Build Coastguard Worker        shape_prop.ShapeProp(traced_3d).propagate(x_3d)
1782*da0073e9SAndroid Build Coastguard Worker        assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
1783*da0073e9SAndroid Build Coastguard Worker                   for node in traced_3d.graph.nodes)
1784*da0073e9SAndroid Build Coastguard Worker
1785*da0073e9SAndroid Build Coastguard Worker        x_channels_last_3d = x_3d.contiguous(memory_format=torch.channels_last_3d)
1786*da0073e9SAndroid Build Coastguard Worker        traced_3d.to(memory_format=torch.channels_last_3d)
1787*da0073e9SAndroid Build Coastguard Worker        shape_prop.ShapeProp(traced_3d).propagate(x_channels_last_3d)
1788*da0073e9SAndroid Build Coastguard Worker        for node in traced_3d.graph.nodes:
1789*da0073e9SAndroid Build Coastguard Worker            # NB: the implementation of conv may not preserve the memory format,
1790*da0073e9SAndroid Build Coastguard Worker            # unfortunately. The best we can do is just check that the placeholder
1791*da0073e9SAndroid Build Coastguard Worker            # node is channels-last
1792*da0073e9SAndroid Build Coastguard Worker            if node.op in {'placeholder'}:
1793*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last_3d)
1794*da0073e9SAndroid Build Coastguard Worker
1795*da0073e9SAndroid Build Coastguard Worker    def test_nn_module_stack(self):
1796*da0073e9SAndroid Build Coastguard Worker        class SubModule(torch.nn.Module):
1797*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1798*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1799*da0073e9SAndroid Build Coastguard Worker                self.conv_mod = torch.nn.Conv2d(64, 64, (3, 3), padding=1, bias=False)
1800*da0073e9SAndroid Build Coastguard Worker
1801*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1802*da0073e9SAndroid Build Coastguard Worker                return self.conv_mod(x)
1803*da0073e9SAndroid Build Coastguard Worker
1804*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
1805*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1806*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1807*da0073e9SAndroid Build Coastguard Worker                self.sub_mod = SubModule()
1808*da0073e9SAndroid Build Coastguard Worker
1809*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1810*da0073e9SAndroid Build Coastguard Worker                return self.sub_mod(x)
1811*da0073e9SAndroid Build Coastguard Worker
1812*da0073e9SAndroid Build Coastguard Worker        m = MyModule()
1813*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.symbolic_trace(m)
1814*da0073e9SAndroid Build Coastguard Worker
1815*da0073e9SAndroid Build Coastguard Worker        mod_stack = {}
1816*da0073e9SAndroid Build Coastguard Worker        expected_stack = [('sub_mod', ('sub_mod', type(m.sub_mod))),
1817*da0073e9SAndroid Build Coastguard Worker                          ('sub_mod.conv_mod', ('sub_mod.conv_mod', type(m.sub_mod.conv_mod)))]
1818*da0073e9SAndroid Build Coastguard Worker        for node in gm.graph.nodes:
1819*da0073e9SAndroid Build Coastguard Worker            mod_stack = node.meta.get('nn_module_stack', {})
1820*da0073e9SAndroid Build Coastguard Worker            if mod_stack:
1821*da0073e9SAndroid Build Coastguard Worker                break
1822*da0073e9SAndroid Build Coastguard Worker        stack_list = list(mod_stack.items())
1823*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(stack_list, expected_stack)
1824*da0073e9SAndroid Build Coastguard Worker
1825*da0073e9SAndroid Build Coastguard Worker    def test_transformer_preserves_nn_module_stack_for_get_attr(self):
1826*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
1827*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1828*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1829*da0073e9SAndroid Build Coastguard Worker                self.weight = torch.nn.Parameter(torch.ones(1, 1))
1830*da0073e9SAndroid Build Coastguard Worker
1831*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1832*da0073e9SAndroid Build Coastguard Worker                return self.weight + x
1833*da0073e9SAndroid Build Coastguard Worker
1834*da0073e9SAndroid Build Coastguard Worker        tracer = torch.fx.Tracer()
1835*da0073e9SAndroid Build Coastguard Worker        graph = tracer.trace(M())
1836*da0073e9SAndroid Build Coastguard Worker        gm = GraphModule(tracer.root, graph)
1837*da0073e9SAndroid Build Coastguard Worker        for node in gm.graph.nodes:
1838*da0073e9SAndroid Build Coastguard Worker            if node.op == 'get_attr':
1839*da0073e9SAndroid Build Coastguard Worker                node.meta["nn_module_stack"] = "self"
1840*da0073e9SAndroid Build Coastguard Worker                node.meta["stack_trace"] = "stack_trace"
1841*da0073e9SAndroid Build Coastguard Worker                node.meta["source_fn_stack"] = "source_fn_stack"
1842*da0073e9SAndroid Build Coastguard Worker        new_gm = Transformer(gm).transform()
1843*da0073e9SAndroid Build Coastguard Worker        for node in new_gm.graph.nodes:
1844*da0073e9SAndroid Build Coastguard Worker            if node.op == 'get_attr':
1845*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(node.meta["nn_module_stack"], "self")
1846*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(node.meta["stack_trace"], "stack_trace")
1847*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(node.meta["source_fn_stack"], "source_fn_stack")
1848*da0073e9SAndroid Build Coastguard Worker
1849*da0073e9SAndroid Build Coastguard Worker    def test_interpreter(self):
1850*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
1851*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1852*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1853*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(3, 4))
1854*da0073e9SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(4, 5)
1855*da0073e9SAndroid Build Coastguard Worker
1856*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1857*da0073e9SAndroid Build Coastguard Worker                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
1858*da0073e9SAndroid Build Coastguard Worker
1859*da0073e9SAndroid Build Coastguard Worker        m = MyModule()
1860*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.symbolic_trace(m)
1861*da0073e9SAndroid Build Coastguard Worker
1862*da0073e9SAndroid Build Coastguard Worker        interpreter = Interpreter(gm)
1863*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, 4)
1864*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(interpreter.run(input), gm(input))
1865*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(interpreter.run(input), m(input))
1866*da0073e9SAndroid Build Coastguard Worker
1867*da0073e9SAndroid Build Coastguard Worker    def test_interpreter_other_graph(self):
1868*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
1869*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1870*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1871*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(3, 4))
1872*da0073e9SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(4, 5)
1873*da0073e9SAndroid Build Coastguard Worker
1874*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1875*da0073e9SAndroid Build Coastguard Worker                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
1876*da0073e9SAndroid Build Coastguard Worker
1877*da0073e9SAndroid Build Coastguard Worker        m = MyModule()
1878*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.symbolic_trace(m)
1879*da0073e9SAndroid Build Coastguard Worker
1880*da0073e9SAndroid Build Coastguard Worker        interpreter = Interpreter(gm, graph=gm.graph)
1881*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, 4)
1882*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(interpreter.run(input), gm(input))
1883*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(interpreter.run(input), m(input))
1884*da0073e9SAndroid Build Coastguard Worker
1885*da0073e9SAndroid Build Coastguard Worker    def test_interpreter_run_node_override(self):
1886*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
1887*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1888*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1889*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(3, 4))
1890*da0073e9SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(4, 5)
1891*da0073e9SAndroid Build Coastguard Worker
1892*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1893*da0073e9SAndroid Build Coastguard Worker                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
1894*da0073e9SAndroid Build Coastguard Worker
1895*da0073e9SAndroid Build Coastguard Worker        m = MyModule()
1896*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.symbolic_trace(m)
1897*da0073e9SAndroid Build Coastguard Worker
1898*da0073e9SAndroid Build Coastguard Worker        class RunNodeInterpreter(Interpreter):
1899*da0073e9SAndroid Build Coastguard Worker            def __init__(self, module):
1900*da0073e9SAndroid Build Coastguard Worker                super().__init__(module)
1901*da0073e9SAndroid Build Coastguard Worker
1902*da0073e9SAndroid Build Coastguard Worker            def run_node(self, n : Node) -> Any:
1903*da0073e9SAndroid Build Coastguard Worker                result = super().run_node(n)
1904*da0073e9SAndroid Build Coastguard Worker                n.cached_value = result
1905*da0073e9SAndroid Build Coastguard Worker                return result
1906*da0073e9SAndroid Build Coastguard Worker
1907*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, 4)
1908*da0073e9SAndroid Build Coastguard Worker        RunNodeInterpreter(gm).run(input)
1909*da0073e9SAndroid Build Coastguard Worker        for node in gm.graph.nodes:
1910*da0073e9SAndroid Build Coastguard Worker            assert hasattr(node, 'cached_value')
1911*da0073e9SAndroid Build Coastguard Worker
1912*da0073e9SAndroid Build Coastguard Worker    def test_interpreter_onthefly_swap(self):
1913*da0073e9SAndroid Build Coastguard Worker
1914*da0073e9SAndroid Build Coastguard Worker        def fn(x):
1915*da0073e9SAndroid Build Coastguard Worker            return torch.sigmoid(x).neg()
1916*da0073e9SAndroid Build Coastguard Worker
1917*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.symbolic_trace(fn)
1918*da0073e9SAndroid Build Coastguard Worker
1919*da0073e9SAndroid Build Coastguard Worker        class NegSigmSwapInterpreter(Interpreter):
1920*da0073e9SAndroid Build Coastguard Worker            def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
1921*da0073e9SAndroid Build Coastguard Worker                if target == torch.sigmoid:
1922*da0073e9SAndroid Build Coastguard Worker                    return torch.neg(*args, **kwargs)
1923*da0073e9SAndroid Build Coastguard Worker                return super().call_function(n)  # noqa: F821
1924*da0073e9SAndroid Build Coastguard Worker
1925*da0073e9SAndroid Build Coastguard Worker            def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
1926*da0073e9SAndroid Build Coastguard Worker                if target == 'neg':
1927*da0073e9SAndroid Build Coastguard Worker                    call_self, *args_tail = args
1928*da0073e9SAndroid Build Coastguard Worker                    return call_self.sigmoid(*args_tail, **kwargs)
1929*da0073e9SAndroid Build Coastguard Worker                return super().call_method(n)  # noqa: F821
1930*da0073e9SAndroid Build Coastguard Worker
1931*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, 4)
1932*da0073e9SAndroid Build Coastguard Worker        result = NegSigmSwapInterpreter(gm).run(input)
1933*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, torch.neg(input).sigmoid())
1934*da0073e9SAndroid Build Coastguard Worker
1935*da0073e9SAndroid Build Coastguard Worker    def test_interpreter_partial_eval(self):
1936*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
1937*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1938*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1939*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(3, 4))
1940*da0073e9SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(4, 5)
1941*da0073e9SAndroid Build Coastguard Worker
1942*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1943*da0073e9SAndroid Build Coastguard Worker                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
1944*da0073e9SAndroid Build Coastguard Worker
1945*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.symbolic_trace(MyModule())
1946*da0073e9SAndroid Build Coastguard Worker        interp = Interpreter(gm)
1947*da0073e9SAndroid Build Coastguard Worker        env = {}
1948*da0073e9SAndroid Build Coastguard Worker        for node in gm.graph.nodes:
1949*da0073e9SAndroid Build Coastguard Worker            if node.op == 'call_module' and node.target == 'linear':
1950*da0073e9SAndroid Build Coastguard Worker                env[node] = torch.arange(0, 12, 1).reshape(3, 4) - 6.0
1951*da0073e9SAndroid Build Coastguard Worker                break
1952*da0073e9SAndroid Build Coastguard Worker        assert len(env) == 1
1953*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 4)
1954*da0073e9SAndroid Build Coastguard Worker        result = interp.run(x, initial_env=env)
1955*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, (torch.arange(0, 12, 1).reshape(3, 4) - 6.0).clamp(0.0, 1.0))
1956*da0073e9SAndroid Build Coastguard Worker
1957*da0073e9SAndroid Build Coastguard Worker    def test_interpreter_star_args(self):
1958*da0073e9SAndroid Build Coastguard Worker        def with_star_args(x, *args):
1959*da0073e9SAndroid Build Coastguard Worker            return x + args[0]
1960*da0073e9SAndroid Build Coastguard Worker
1961*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.symbolic_trace(with_star_args)
1962*da0073e9SAndroid Build Coastguard Worker        interp = Interpreter(gm)
1963*da0073e9SAndroid Build Coastguard Worker        result = interp.run(torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4))
1964*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, torch.ones(3, 4) * 2.0)
1965*da0073e9SAndroid Build Coastguard Worker
1966*da0073e9SAndroid Build Coastguard Worker    @skipIfNoTorchVision
1967*da0073e9SAndroid Build Coastguard Worker    def test_interpreter_noop_resnet18(self):
1968*da0073e9SAndroid Build Coastguard Worker        rn18 = torchvision_models.resnet18()
1969*da0073e9SAndroid Build Coastguard Worker        transformed = torch.fx.Transformer(symbolic_trace(rn18)).transform()
1970*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(5, 3, 224, 224)
1971*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(transformed(inp), rn18(inp))
1972*da0073e9SAndroid Build Coastguard Worker
1973*da0073e9SAndroid Build Coastguard Worker    @skipIfNoTorchVision
1974*da0073e9SAndroid Build Coastguard Worker    def test_interpreter_gc_values(self):
1975*da0073e9SAndroid Build Coastguard Worker        rn18 = torchvision_models.resnet18()
1976*da0073e9SAndroid Build Coastguard Worker        interp = Interpreter(symbolic_trace(rn18))
1977*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(5, 3, 224, 224)
1978*da0073e9SAndroid Build Coastguard Worker        out = interp.run(inp)
1979*da0073e9SAndroid Build Coastguard Worker        env_key_names = {n.name for n in interp.env.keys()}
1980*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(env_key_names, {'output'})
1981*da0073e9SAndroid Build Coastguard Worker
1982*da0073e9SAndroid Build Coastguard Worker    def test_interpreter_default_args(self):
1983*da0073e9SAndroid Build Coastguard Worker        class Model(torch.nn.Module):
1984*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y=3.14159):
1985*da0073e9SAndroid Build Coastguard Worker                return x + y
1986*da0073e9SAndroid Build Coastguard Worker
1987*da0073e9SAndroid Build Coastguard Worker        model = Model()
1988*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.symbolic_trace(model)
1989*da0073e9SAndroid Build Coastguard Worker
1990*da0073e9SAndroid Build Coastguard Worker        interp = Interpreter(gm)
1991*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 3)
1992*da0073e9SAndroid Build Coastguard Worker        out = interp.run(x)
1993*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(out, x + 3.14159)
1994*da0073e9SAndroid Build Coastguard Worker
1995*da0073e9SAndroid Build Coastguard Worker    def test_interpreter_not_enough_args(self):
1996*da0073e9SAndroid Build Coastguard Worker        class Model(torch.nn.Module):
1997*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y):
1998*da0073e9SAndroid Build Coastguard Worker                return x + y
1999*da0073e9SAndroid Build Coastguard Worker
2000*da0073e9SAndroid Build Coastguard Worker        model = Model()
2001*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.symbolic_trace(model)
2002*da0073e9SAndroid Build Coastguard Worker
2003*da0073e9SAndroid Build Coastguard Worker        interp = Interpreter(gm)
2004*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 3)
2005*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
2006*da0073e9SAndroid Build Coastguard Worker                                    'Expected positional argument for parameter y, but one was not passed in'):
2007*da0073e9SAndroid Build Coastguard Worker            out = interp.run(x)
2008*da0073e9SAndroid Build Coastguard Worker
2009*da0073e9SAndroid Build Coastguard Worker    def test_transformer_noop(self):
2010*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
2011*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2012*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2013*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(3, 4))
2014*da0073e9SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(4, 5)
2015*da0073e9SAndroid Build Coastguard Worker
2016*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2017*da0073e9SAndroid Build Coastguard Worker                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
2018*da0073e9SAndroid Build Coastguard Worker
2019*da0073e9SAndroid Build Coastguard Worker        m = MyModule()
2020*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.symbolic_trace(m)
2021*da0073e9SAndroid Build Coastguard Worker
2022*da0073e9SAndroid Build Coastguard Worker        new_gm = Transformer(gm).transform()
2023*da0073e9SAndroid Build Coastguard Worker
2024*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, 4)
2025*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(new_gm(input), gm(input))
2026*da0073e9SAndroid Build Coastguard Worker
2027*da0073e9SAndroid Build Coastguard Worker    def test_transformer_op_swap(self):
2028*da0073e9SAndroid Build Coastguard Worker
2029*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2030*da0073e9SAndroid Build Coastguard Worker            return torch.sigmoid(x).neg()
2031*da0073e9SAndroid Build Coastguard Worker
2032*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.symbolic_trace(fn)
2033*da0073e9SAndroid Build Coastguard Worker
2034*da0073e9SAndroid Build Coastguard Worker        class NegSigmSwapXformer(Transformer):
2035*da0073e9SAndroid Build Coastguard Worker            def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
2036*da0073e9SAndroid Build Coastguard Worker                if target == torch.sigmoid:
2037*da0073e9SAndroid Build Coastguard Worker                    return torch.neg(*args, **kwargs)
2038*da0073e9SAndroid Build Coastguard Worker                return super().call_function(n)  # noqa: F821
2039*da0073e9SAndroid Build Coastguard Worker
2040*da0073e9SAndroid Build Coastguard Worker            def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
2041*da0073e9SAndroid Build Coastguard Worker                if target == 'neg':
2042*da0073e9SAndroid Build Coastguard Worker                    call_self, *args_tail = args
2043*da0073e9SAndroid Build Coastguard Worker                    return call_self.sigmoid(*args_tail, **kwargs)
2044*da0073e9SAndroid Build Coastguard Worker                return super().call_method(n)  # noqa: F821
2045*da0073e9SAndroid Build Coastguard Worker
2046*da0073e9SAndroid Build Coastguard Worker        transformed = NegSigmSwapXformer(gm).transform()
2047*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, 4)
2048*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(transformed(input), torch.neg(input).sigmoid())
2049*da0073e9SAndroid Build Coastguard Worker
2050*da0073e9SAndroid Build Coastguard Worker    def test_transformer_multi_outputs(self):
2051*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
2052*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2053*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2054*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(3, 4))
2055*da0073e9SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(4, 5)
2056*da0073e9SAndroid Build Coastguard Worker
2057*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2058*da0073e9SAndroid Build Coastguard Worker                x = x + self.param
2059*da0073e9SAndroid Build Coastguard Worker                out = self.linear(x)
2060*da0073e9SAndroid Build Coastguard Worker                return x, out
2061*da0073e9SAndroid Build Coastguard Worker
2062*da0073e9SAndroid Build Coastguard Worker        m = MyModule()
2063*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.symbolic_trace(m)
2064*da0073e9SAndroid Build Coastguard Worker
2065*da0073e9SAndroid Build Coastguard Worker        new_gm = Transformer(gm).transform()
2066*da0073e9SAndroid Build Coastguard Worker
2067*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, 4)
2068*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(new_gm(input), gm(input))
2069*da0073e9SAndroid Build Coastguard Worker
2070*da0073e9SAndroid Build Coastguard Worker    def test_fn_type_annotations(self):
2071*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
2072*da0073e9SAndroid Build Coastguard Worker            def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]:
2073*da0073e9SAndroid Build Coastguard Worker                return {'a': p.x + p.y + z + i}
2074*da0073e9SAndroid Build Coastguard Worker
2075*da0073e9SAndroid Build Coastguard Worker        foo_scripted = torch.jit.script(Foo())
2076*da0073e9SAndroid Build Coastguard Worker        foo_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)
2077*da0073e9SAndroid Build Coastguard Worker
2078*da0073e9SAndroid Build Coastguard Worker        fxed = symbolic_trace(Foo())
2079*da0073e9SAndroid Build Coastguard Worker        fxed_scripted = torch.jit.script(fxed)
2080*da0073e9SAndroid Build Coastguard Worker        fxed_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)
2081*da0073e9SAndroid Build Coastguard Worker
2082*da0073e9SAndroid Build Coastguard Worker    def test_fn_type_annotation_empty(self):
2083*da0073e9SAndroid Build Coastguard Worker        def forward(a : List[torch.Tensor]):
2084*da0073e9SAndroid Build Coastguard Worker            return a[0]
2085*da0073e9SAndroid Build Coastguard Worker        torch.jit.script(symbolic_trace(forward))
2086*da0073e9SAndroid Build Coastguard Worker
2087*da0073e9SAndroid Build Coastguard Worker    def test_wrapped_method(self):
2088*da0073e9SAndroid Build Coastguard Worker        def wrap_with_relu(fn):
2089*da0073e9SAndroid Build Coastguard Worker            @functools.wraps(fn)
2090*da0073e9SAndroid Build Coastguard Worker            def wrapper(*args, **kwargs):
2091*da0073e9SAndroid Build Coastguard Worker                return torch.relu(fn(*args, **kwargs))
2092*da0073e9SAndroid Build Coastguard Worker            return wrapper
2093*da0073e9SAndroid Build Coastguard Worker
2094*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
2095*da0073e9SAndroid Build Coastguard Worker            @wrap_with_relu
2096*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, w):
2097*da0073e9SAndroid Build Coastguard Worker                return torch.matmul(x, w)
2098*da0073e9SAndroid Build Coastguard Worker
2099*da0073e9SAndroid Build Coastguard Worker        f = Foo()
2100*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(f)
2101*da0073e9SAndroid Build Coastguard Worker        x, w = torch.rand(3, 4), torch.rand(4, 4)
2102*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(any(n.target == torch.relu for n in traced.graph.nodes))
2103*da0073e9SAndroid Build Coastguard Worker
2104*da0073e9SAndroid Build Coastguard Worker    def test_empty_graph_codegen(self):
2105*da0073e9SAndroid Build Coastguard Worker        graph = torch.fx.Graph()
2106*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2107*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gm(), None)
2108*da0073e9SAndroid Build Coastguard Worker
2109*da0073e9SAndroid Build Coastguard Worker    def test_sequential(self):
2110*da0073e9SAndroid Build Coastguard Worker        m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1))
2111*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.symbolic_trace(m)
2112*da0073e9SAndroid Build Coastguard Worker        gm_copy = copy.deepcopy(gm)
2113*da0073e9SAndroid Build Coastguard Worker
2114*da0073e9SAndroid Build Coastguard Worker    def test_ctx_mgr(self):
2115*da0073e9SAndroid Build Coastguard Worker        @contextlib.contextmanager
2116*da0073e9SAndroid Build Coastguard Worker        def do_nothing():
2117*da0073e9SAndroid Build Coastguard Worker            yield
2118*da0073e9SAndroid Build Coastguard Worker
2119*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
2120*da0073e9SAndroid Build Coastguard Worker            @do_nothing()
2121*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2122*da0073e9SAndroid Build Coastguard Worker                return torch.relu(x)
2123*da0073e9SAndroid Build Coastguard Worker
2124*da0073e9SAndroid Build Coastguard Worker        m = M()
2125*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(m, (torch.rand(3, 4),))
2126*da0073e9SAndroid Build Coastguard Worker
2127*da0073e9SAndroid Build Coastguard Worker    def test_typename_print(self):
2128*da0073e9SAndroid Build Coastguard Worker        graph : torch.fx.Graph = torch.fx.Graph()
2129*da0073e9SAndroid Build Coastguard Worker        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2130*da0073e9SAndroid Build Coastguard Worker        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,),
2131*da0073e9SAndroid Build Coastguard Worker                                              type_expr=List[float])
2132*da0073e9SAndroid Build Coastguard Worker        output : torch.fx.Node = graph.output(b)
2133*da0073e9SAndroid Build Coastguard Worker
2134*da0073e9SAndroid Build Coastguard Worker        self.assertTrue('typing.List[float]' in str(graph))
2135*da0073e9SAndroid Build Coastguard Worker
2136*da0073e9SAndroid Build Coastguard Worker    def test_layout(self):
2137*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
2138*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2139*da0073e9SAndroid Build Coastguard Worker                return torch.empty_like(x, layout=torch.strided, pin_memory=False).fill_(0)
2140*da0073e9SAndroid Build Coastguard Worker
2141*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(M())
2142*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(5, 9, 3, 4)
2143*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced(x), torch.zeros_like(x))
2144*da0073e9SAndroid Build Coastguard Worker
2145*da0073e9SAndroid Build Coastguard Worker    def test_ellipsis(self):
2146*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
2147*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y):
2148*da0073e9SAndroid Build Coastguard Worker                return x + y[:, 1:10, ...]
2149*da0073e9SAndroid Build Coastguard Worker
2150*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(M())
2151*da0073e9SAndroid Build Coastguard Worker        x, y = torch.rand(5, 9, 3, 4), torch.rand(5, 15, 3, 4)
2152*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced(x, y), x + y[:, 1:10, ...])
2153*da0073e9SAndroid Build Coastguard Worker
2154*da0073e9SAndroid Build Coastguard Worker    def test_inf_nan(self):
2155*da0073e9SAndroid Build Coastguard Worker        class FooMod(torch.nn.Module):
2156*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2157*da0073e9SAndroid Build Coastguard Worker                return x + float('inf'), x + float('-inf'), x + float('nan')
2158*da0073e9SAndroid Build Coastguard Worker
2159*da0073e9SAndroid Build Coastguard Worker        fm = FooMod()
2160*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(fm, (torch.rand(3, 4),))
2161*da0073e9SAndroid Build Coastguard Worker
2162*da0073e9SAndroid Build Coastguard Worker    def test_inf_nan_kwds(self):
2163*da0073e9SAndroid Build Coastguard Worker        graph : torch.fx.Graph = torch.fx.Graph()
2164*da0073e9SAndroid Build Coastguard Worker        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2165*da0073e9SAndroid Build Coastguard Worker        b : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('inf')), {}, name='inf')
2166*da0073e9SAndroid Build Coastguard Worker        c : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('nan')), {}, name='nan')
2167*da0073e9SAndroid Build Coastguard Worker        graph.output((b, c))
2168*da0073e9SAndroid Build Coastguard Worker
2169*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2170*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
2171*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gm(x), (x + float('inf'), x + float('nan')))
2172*da0073e9SAndroid Build Coastguard Worker
2173*da0073e9SAndroid Build Coastguard Worker    def test_deepcopy_recursion_depth(self):
2174*da0073e9SAndroid Build Coastguard Worker        depth = sys.getrecursionlimit() + 20
2175*da0073e9SAndroid Build Coastguard Worker
2176*da0073e9SAndroid Build Coastguard Worker        g = torch.fx.Graph()
2177*da0073e9SAndroid Build Coastguard Worker        x = g.placeholder('x')
2178*da0073e9SAndroid Build Coastguard Worker        for i in range(depth):
2179*da0073e9SAndroid Build Coastguard Worker            x = g.call_function(torch.relu, (x,))
2180*da0073e9SAndroid Build Coastguard Worker        g.output(x)
2181*da0073e9SAndroid Build Coastguard Worker
2182*da0073e9SAndroid Build Coastguard Worker        copied_graph = copy.deepcopy(g)
2183*da0073e9SAndroid Build Coastguard Worker
2184*da0073e9SAndroid Build Coastguard Worker        val_map = {}
2185*da0073e9SAndroid Build Coastguard Worker        for orig_node, new_node in zip(g.nodes, copied_graph.nodes):
2186*da0073e9SAndroid Build Coastguard Worker            val_map[orig_node] = new_node
2187*da0073e9SAndroid Build Coastguard Worker
2188*da0073e9SAndroid Build Coastguard Worker        for orig_node, new_node in zip(g.nodes, copied_graph.nodes):
2189*da0073e9SAndroid Build Coastguard Worker            orig_users = set(orig_node.users.keys())
2190*da0073e9SAndroid Build Coastguard Worker            orig_users_equiv = {val_map[u] for u in orig_users}
2191*da0073e9SAndroid Build Coastguard Worker            new_users = set(new_node.users.keys())
2192*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(orig_users_equiv, new_users)
2193*da0073e9SAndroid Build Coastguard Worker
2194*da0073e9SAndroid Build Coastguard Worker    @skipIfNoTorchVision
2195*da0073e9SAndroid Build Coastguard Worker    def test_replace_uses(self):
2196*da0073e9SAndroid Build Coastguard Worker        rn18 = torchvision_models.resnet18()
2197*da0073e9SAndroid Build Coastguard Worker
2198*da0073e9SAndroid Build Coastguard Worker        class LowerReluTracer(torch.fx.Tracer):
2199*da0073e9SAndroid Build Coastguard Worker            def is_leaf_module(self, m : torch.nn.Module, qualname : str):
2200*da0073e9SAndroid Build Coastguard Worker                if isinstance(m, torch.nn.ReLU):
2201*da0073e9SAndroid Build Coastguard Worker                    return False
2202*da0073e9SAndroid Build Coastguard Worker                return super().is_leaf_module(m, qualname)
2203*da0073e9SAndroid Build Coastguard Worker
2204*da0073e9SAndroid Build Coastguard Worker        rn18_traced = GraphModule(rn18, LowerReluTracer().trace(rn18))
2205*da0073e9SAndroid Build Coastguard Worker
2206*da0073e9SAndroid Build Coastguard Worker        to_erase = []
2207*da0073e9SAndroid Build Coastguard Worker        for node in rn18_traced.graph.nodes:
2208*da0073e9SAndroid Build Coastguard Worker            if node.op == 'call_function' and node.target in [torch.relu, torch.nn.functional.relu]:
2209*da0073e9SAndroid Build Coastguard Worker                kwargs = node.kwargs.copy()
2210*da0073e9SAndroid Build Coastguard Worker                # Neg doesn't have in-place
2211*da0073e9SAndroid Build Coastguard Worker                kwargs.pop('inplace')
2212*da0073e9SAndroid Build Coastguard Worker                with rn18_traced.graph.inserting_before(node):
2213*da0073e9SAndroid Build Coastguard Worker                    new_node = rn18_traced.graph.call_function(
2214*da0073e9SAndroid Build Coastguard Worker                        the_function=torch.neg, args=node.args, kwargs=node.kwargs)
2215*da0073e9SAndroid Build Coastguard Worker                node.replace_all_uses_with(replace_with=new_node)
2216*da0073e9SAndroid Build Coastguard Worker                to_erase.append(node)
2217*da0073e9SAndroid Build Coastguard Worker
2218*da0073e9SAndroid Build Coastguard Worker        for node in to_erase:
2219*da0073e9SAndroid Build Coastguard Worker            rn18_traced.graph.erase_node(node)
2220*da0073e9SAndroid Build Coastguard Worker
2221*da0073e9SAndroid Build Coastguard Worker    def test_replace_input(self):
2222*da0073e9SAndroid Build Coastguard Worker        graph : torch.fx.Graph = torch.fx.Graph()
2223*da0073e9SAndroid Build Coastguard Worker        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2224*da0073e9SAndroid Build Coastguard Worker        y : torch.fx.Node = graph.create_node('placeholder', 'y')
2225*da0073e9SAndroid Build Coastguard Worker        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
2226*da0073e9SAndroid Build Coastguard Worker        output : torch.fx.Node = graph.output(b)
2227*da0073e9SAndroid Build Coastguard Worker
2228*da0073e9SAndroid Build Coastguard Worker        b.replace_input_with(x, y)
2229*da0073e9SAndroid Build Coastguard Worker
2230*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2231*da0073e9SAndroid Build Coastguard Worker
2232*da0073e9SAndroid Build Coastguard Worker        input_x = torch.randn(33, 44)
2233*da0073e9SAndroid Build Coastguard Worker        input_y = torch.randn(11, 22)
2234*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gm(input_x, input_y), torch.relu(input_y))
2235*da0073e9SAndroid Build Coastguard Worker
2236*da0073e9SAndroid Build Coastguard Worker    def test_insertion_point(self):
2237*da0073e9SAndroid Build Coastguard Worker        graph : torch.fx.Graph = torch.fx.Graph()
2238*da0073e9SAndroid Build Coastguard Worker        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2239*da0073e9SAndroid Build Coastguard Worker        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
2240*da0073e9SAndroid Build Coastguard Worker        output : torch.fx.Node = graph.output(b)
2241*da0073e9SAndroid Build Coastguard Worker
2242*da0073e9SAndroid Build Coastguard Worker        with graph.inserting_before(b):
2243*da0073e9SAndroid Build Coastguard Worker            neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
2244*da0073e9SAndroid Build Coastguard Worker            _, *relu_args = b.args
2245*da0073e9SAndroid Build Coastguard Worker            b.args = (neg, *relu_args)
2246*da0073e9SAndroid Build Coastguard Worker
2247*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2248*da0073e9SAndroid Build Coastguard Worker
2249*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(33, 44)
2250*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gm(input), torch.relu(torch.neg(input)))
2251*da0073e9SAndroid Build Coastguard Worker
2252*da0073e9SAndroid Build Coastguard Worker    def test_update_args_api(self):
2253*da0073e9SAndroid Build Coastguard Worker        graph : torch.fx.Graph = torch.fx.Graph()
2254*da0073e9SAndroid Build Coastguard Worker        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2255*da0073e9SAndroid Build Coastguard Worker        y : torch.fx.Node = graph.create_node('placeholder', 'y')
2256*da0073e9SAndroid Build Coastguard Worker        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
2257*da0073e9SAndroid Build Coastguard Worker        output : torch.fx.Node = graph.output(b)
2258*da0073e9SAndroid Build Coastguard Worker
2259*da0073e9SAndroid Build Coastguard Worker        orig_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2260*da0073e9SAndroid Build Coastguard Worker        inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5)
2261*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x))
2262*da0073e9SAndroid Build Coastguard Worker
2263*da0073e9SAndroid Build Coastguard Worker        b.update_arg(0, y)
2264*da0073e9SAndroid Build Coastguard Worker        new_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2265*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y))
2266*da0073e9SAndroid Build Coastguard Worker
2267*da0073e9SAndroid Build Coastguard Worker    def test_update_kwargs_api(self):
2268*da0073e9SAndroid Build Coastguard Worker        graph : torch.fx.Graph = torch.fx.Graph()
2269*da0073e9SAndroid Build Coastguard Worker        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2270*da0073e9SAndroid Build Coastguard Worker        y : torch.fx.Node = graph.create_node('placeholder', 'y')
2271*da0073e9SAndroid Build Coastguard Worker        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, kwargs={'input': x})
2272*da0073e9SAndroid Build Coastguard Worker        output : torch.fx.Node = graph.output(b)
2273*da0073e9SAndroid Build Coastguard Worker
2274*da0073e9SAndroid Build Coastguard Worker        orig_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2275*da0073e9SAndroid Build Coastguard Worker        inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5)
2276*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x))
2277*da0073e9SAndroid Build Coastguard Worker
2278*da0073e9SAndroid Build Coastguard Worker        b.update_kwarg('input', y)
2279*da0073e9SAndroid Build Coastguard Worker        new_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2280*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y))
2281*da0073e9SAndroid Build Coastguard Worker
2282*da0073e9SAndroid Build Coastguard Worker    def test_immutable_list_pytree_ops(self):
2283*da0073e9SAndroid Build Coastguard Worker        rand_tensor = torch.randn(5, 3)
2284*da0073e9SAndroid Build Coastguard Worker        l = immutable_list([3, [rand_tensor, 42]])
2285*da0073e9SAndroid Build Coastguard Worker
2286*da0073e9SAndroid Build Coastguard Worker        flattened, spec = pytree.tree_flatten(l)
2287*da0073e9SAndroid Build Coastguard Worker        assert flattened == [3, rand_tensor, 42]
2288*da0073e9SAndroid Build Coastguard Worker
2289*da0073e9SAndroid Build Coastguard Worker        unflattened = pytree.tree_unflatten(flattened, spec)
2290*da0073e9SAndroid Build Coastguard Worker        assert unflattened == l
2291*da0073e9SAndroid Build Coastguard Worker        assert isinstance(unflattened, immutable_list)
2292*da0073e9SAndroid Build Coastguard Worker
2293*da0073e9SAndroid Build Coastguard Worker    def test_immutable_dict_pytree_ops(self):
2294*da0073e9SAndroid Build Coastguard Worker        rand_tensor = torch.randn(5, 3)
2295*da0073e9SAndroid Build Coastguard Worker        d = immutable_dict({'a': 3, 'b': [rand_tensor, 42]})
2296*da0073e9SAndroid Build Coastguard Worker
2297*da0073e9SAndroid Build Coastguard Worker        flattened, spec = pytree.tree_flatten(d)
2298*da0073e9SAndroid Build Coastguard Worker        assert flattened == [3, rand_tensor, 42]
2299*da0073e9SAndroid Build Coastguard Worker
2300*da0073e9SAndroid Build Coastguard Worker        unflattened = pytree.tree_unflatten(flattened, spec)
2301*da0073e9SAndroid Build Coastguard Worker        assert unflattened == d
2302*da0073e9SAndroid Build Coastguard Worker        assert isinstance(unflattened, immutable_dict)
2303*da0073e9SAndroid Build Coastguard Worker
2304*da0073e9SAndroid Build Coastguard Worker    def test_move_before(self):
2305*da0073e9SAndroid Build Coastguard Worker        graph : torch.fx.Graph = torch.fx.Graph()
2306*da0073e9SAndroid Build Coastguard Worker        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2307*da0073e9SAndroid Build Coastguard Worker        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
2308*da0073e9SAndroid Build Coastguard Worker        output : torch.fx.Node = graph.output(b)
2309*da0073e9SAndroid Build Coastguard Worker
2310*da0073e9SAndroid Build Coastguard Worker        neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
2311*da0073e9SAndroid Build Coastguard Worker        _, *relu_args = b.args
2312*da0073e9SAndroid Build Coastguard Worker        b.args = (neg, *relu_args)
2313*da0073e9SAndroid Build Coastguard Worker        b.prepend(neg)
2314*da0073e9SAndroid Build Coastguard Worker
2315*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2316*da0073e9SAndroid Build Coastguard Worker
2317*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(33, 44)
2318*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gm(input), torch.relu(torch.neg(input)))
2319*da0073e9SAndroid Build Coastguard Worker
2320*da0073e9SAndroid Build Coastguard Worker    def test_prepend_self(self):
2321*da0073e9SAndroid Build Coastguard Worker        graph : torch.fx.Graph = torch.fx.Graph()
2322*da0073e9SAndroid Build Coastguard Worker        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2323*da0073e9SAndroid Build Coastguard Worker        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
2324*da0073e9SAndroid Build Coastguard Worker        output : torch.fx.Node = graph.output(b)
2325*da0073e9SAndroid Build Coastguard Worker
2326*da0073e9SAndroid Build Coastguard Worker        b.prepend(b)
2327*da0073e9SAndroid Build Coastguard Worker        x.append(b)
2328*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(graph.nodes), 3)
2329*da0073e9SAndroid Build Coastguard Worker
2330*da0073e9SAndroid Build Coastguard Worker    def test_erase_node_error(self):
2331*da0073e9SAndroid Build Coastguard Worker        st = SimpleTest()
2332*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(st)
2333*da0073e9SAndroid Build Coastguard Worker
2334*da0073e9SAndroid Build Coastguard Worker        for node in traced.graph.nodes:
2335*da0073e9SAndroid Build Coastguard Worker            # Test deleting with uses both in another Node and at the output
2336*da0073e9SAndroid Build Coastguard Worker            if node.target in [operator.add, torch.relu]:
2337*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, 'but it still had .* users in the graph'):
2338*da0073e9SAndroid Build Coastguard Worker                    traced.graph.erase_node(node)
2339*da0073e9SAndroid Build Coastguard Worker
2340*da0073e9SAndroid Build Coastguard Worker    def test_copy_it(self):
2341*da0073e9SAndroid Build Coastguard Worker        d = immutable_dict([(3, 4), (5, 6)])
2342*da0073e9SAndroid Build Coastguard Worker        l = immutable_list([(3, 4), (5, 6)])
2343*da0073e9SAndroid Build Coastguard Worker
2344*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(d, deepcopy(d))
2345*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(l, deepcopy(l))
2346*da0073e9SAndroid Build Coastguard Worker
2347*da0073e9SAndroid Build Coastguard Worker    def test_get_torch_func_signature(self):
2348*da0073e9SAndroid Build Coastguard Worker        for key in dir(torch):
2349*da0073e9SAndroid Build Coastguard Worker            obj = getattr(torch, key)
2350*da0073e9SAndroid Build Coastguard Worker            if callable(obj):
2351*da0073e9SAndroid Build Coastguard Worker                schemas = get_signature_for_torch_op(obj)
2352*da0073e9SAndroid Build Coastguard Worker
2353*da0073e9SAndroid Build Coastguard Worker    def test_find_uses(self):
2354*da0073e9SAndroid Build Coastguard Worker        graph = torch.fx.Graph()
2355*da0073e9SAndroid Build Coastguard Worker        x = torch.fx.Proxy(graph.placeholder('x'))
2356*da0073e9SAndroid Build Coastguard Worker
2357*da0073e9SAndroid Build Coastguard Worker        y = torch.relu(x)
2358*da0073e9SAndroid Build Coastguard Worker        z = x + x
2359*da0073e9SAndroid Build Coastguard Worker        u = torch.neg(x)
2360*da0073e9SAndroid Build Coastguard Worker        graph.output((y + z + u).node)
2361*da0073e9SAndroid Build Coastguard Worker        graph.lint()
2362*da0073e9SAndroid Build Coastguard Worker
2363*da0073e9SAndroid Build Coastguard Worker        users_of_x = x.node.users
2364*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(users_of_x), 3)
2365*da0073e9SAndroid Build Coastguard Worker        expected_ops = {'relu', 'add', 'neg'}
2366*da0073e9SAndroid Build Coastguard Worker        for use in users_of_x:
2367*da0073e9SAndroid Build Coastguard Worker            assert any(use.name.startswith(prefix) for prefix in expected_ops)
2368*da0073e9SAndroid Build Coastguard Worker
2369*da0073e9SAndroid Build Coastguard Worker    def test_inline_graph(self):
2370*da0073e9SAndroid Build Coastguard Worker        class InlineInto(torch.nn.Module):
2371*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2372*da0073e9SAndroid Build Coastguard Worker                return torch.relu(x)
2373*da0073e9SAndroid Build Coastguard Worker
2374*da0073e9SAndroid Build Coastguard Worker        class ToInline(torch.nn.Module):
2375*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2376*da0073e9SAndroid Build Coastguard Worker                return torch.neg(x)
2377*da0073e9SAndroid Build Coastguard Worker
2378*da0073e9SAndroid Build Coastguard Worker        inline_into = symbolic_trace(InlineInto())
2379*da0073e9SAndroid Build Coastguard Worker        to_inline = symbolic_trace(ToInline())
2380*da0073e9SAndroid Build Coastguard Worker
2381*da0073e9SAndroid Build Coastguard Worker        combined_graph = torch.fx.Graph()
2382*da0073e9SAndroid Build Coastguard Worker        output_node = combined_graph.graph_copy(inline_into.graph, {})
2383*da0073e9SAndroid Build Coastguard Worker
2384*da0073e9SAndroid Build Coastguard Worker        input_node = next(iter(to_inline.graph.nodes))
2385*da0073e9SAndroid Build Coastguard Worker        assert input_node and input_node.op == 'placeholder'
2386*da0073e9SAndroid Build Coastguard Worker
2387*da0073e9SAndroid Build Coastguard Worker        val_map = {input_node : output_node}
2388*da0073e9SAndroid Build Coastguard Worker        output = combined_graph.graph_copy(to_inline.graph, val_map)
2389*da0073e9SAndroid Build Coastguard Worker        combined_graph.output(output)
2390*da0073e9SAndroid Build Coastguard Worker
2391*da0073e9SAndroid Build Coastguard Worker        combined_module = torch.fx.GraphModule(torch.nn.Module(), combined_graph)
2392*da0073e9SAndroid Build Coastguard Worker
2393*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(3, 4)
2394*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(combined_module(input), input.relu().neg())
2395*da0073e9SAndroid Build Coastguard Worker
2396*da0073e9SAndroid Build Coastguard Worker    def test_multi_insert_point(self):
2397*da0073e9SAndroid Build Coastguard Worker        graph = torch.fx.Graph()
2398*da0073e9SAndroid Build Coastguard Worker        x = torch.fx.Proxy(graph.placeholder('x'))
2399*da0073e9SAndroid Build Coastguard Worker        relu = torch.relu(x)
2400*da0073e9SAndroid Build Coastguard Worker
2401*da0073e9SAndroid Build Coastguard Worker        with graph.inserting_before(relu.node):
2402*da0073e9SAndroid Build Coastguard Worker            y = torch.neg(x)
2403*da0073e9SAndroid Build Coastguard Worker            z = torch.tanh(y)
2404*da0073e9SAndroid Build Coastguard Worker
2405*da0073e9SAndroid Build Coastguard Worker        graph.output((relu.node, z.node))
2406*da0073e9SAndroid Build Coastguard Worker        graph.lint()
2407*da0073e9SAndroid Build Coastguard Worker
2408*da0073e9SAndroid Build Coastguard Worker        expected_ops = ['x', 'neg', 'tanh', 'relu']
2409*da0073e9SAndroid Build Coastguard Worker        for node, expected in zip(graph.nodes, expected_ops):
2410*da0073e9SAndroid Build Coastguard Worker            assert expected in node.name
2411*da0073e9SAndroid Build Coastguard Worker
2412*da0073e9SAndroid Build Coastguard Worker    def test_reassign_args_kwargs_uses(self):
2413*da0073e9SAndroid Build Coastguard Worker        graph = torch.fx.Graph()
2414*da0073e9SAndroid Build Coastguard Worker        x, y = Proxy(graph.placeholder('x')), Proxy(graph.placeholder('y'))
2415*da0073e9SAndroid Build Coastguard Worker        z = x + y
2416*da0073e9SAndroid Build Coastguard Worker        zed = z + z + z
2417*da0073e9SAndroid Build Coastguard Worker        graph.output(zed.node)
2418*da0073e9SAndroid Build Coastguard Worker        graph.lint()
2419*da0073e9SAndroid Build Coastguard Worker
2420*da0073e9SAndroid Build Coastguard Worker        # zed = z + z + z -> zed = z + z + x
2421*da0073e9SAndroid Build Coastguard Worker        zed.node.args = (zed.node.args[0], x.node)
2422*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(x.node.users.keys()), [z.node, zed.node])
2423*da0073e9SAndroid Build Coastguard Worker
2424*da0073e9SAndroid Build Coastguard Worker        # z = x + y -> z = y + y
2425*da0073e9SAndroid Build Coastguard Worker        z.node.args = (y.node, y.node)
2426*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(x.node.users.keys()), [zed.node])
2427*da0073e9SAndroid Build Coastguard Worker
2428*da0073e9SAndroid Build Coastguard Worker    def test_trace_function(self):
2429*da0073e9SAndroid Build Coastguard Worker        def foo(x, y):
2430*da0073e9SAndroid Build Coastguard Worker            return torch.relu(x) + y
2431*da0073e9SAndroid Build Coastguard Worker
2432*da0073e9SAndroid Build Coastguard Worker        x, y = torch.randn(3, 4), torch.randn(3, 4)
2433*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(foo, (x, y))
2434*da0073e9SAndroid Build Coastguard Worker
2435*da0073e9SAndroid Build Coastguard Worker    def test_trace_return_dataclass(self):
2436*da0073e9SAndroid Build Coastguard Worker        """
2437*da0073e9SAndroid Build Coastguard Worker        Test case for Module that return dataclass
2438*da0073e9SAndroid Build Coastguard Worker        """
2439*da0073e9SAndroid Build Coastguard Worker        from dataclasses import dataclass
2440*da0073e9SAndroid Build Coastguard Worker
2441*da0073e9SAndroid Build Coastguard Worker        @dataclass
2442*da0073e9SAndroid Build Coastguard Worker        class MyOutput:
2443*da0073e9SAndroid Build Coastguard Worker            foo: torch.Tensor
2444*da0073e9SAndroid Build Coastguard Worker            bar: torch.Tensor
2445*da0073e9SAndroid Build Coastguard Worker
2446*da0073e9SAndroid Build Coastguard Worker        class ModuleReturnDataclass(torch.nn.Module):
2447*da0073e9SAndroid Build Coastguard Worker            def forward(self, d : torch.Tensor):
2448*da0073e9SAndroid Build Coastguard Worker                return MyOutput(foo=d + d, bar=d * 3)
2449*da0073e9SAndroid Build Coastguard Worker
2450*da0073e9SAndroid Build Coastguard Worker        module = ModuleReturnDataclass()
2451*da0073e9SAndroid Build Coastguard Worker        traced_graph = symbolic_trace(module).graph
2452*da0073e9SAndroid Build Coastguard Worker        print(traced_graph)
2453*da0073e9SAndroid Build Coastguard Worker
2454*da0073e9SAndroid Build Coastguard Worker        gm = GraphModule(module, traced_graph)
2455*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(1)
2456*da0073e9SAndroid Build Coastguard Worker
2457*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module(x), gm(x))
2458*da0073e9SAndroid Build Coastguard Worker
2459*da0073e9SAndroid Build Coastguard Worker    def test_trace_return_dataclass_nested(self):
2460*da0073e9SAndroid Build Coastguard Worker        """
2461*da0073e9SAndroid Build Coastguard Worker        Test case for Module that return dataclass
2462*da0073e9SAndroid Build Coastguard Worker        """
2463*da0073e9SAndroid Build Coastguard Worker        from dataclasses import dataclass
2464*da0073e9SAndroid Build Coastguard Worker
2465*da0073e9SAndroid Build Coastguard Worker        @dataclass
2466*da0073e9SAndroid Build Coastguard Worker        class MyOutput:
2467*da0073e9SAndroid Build Coastguard Worker            foo: torch.Tensor
2468*da0073e9SAndroid Build Coastguard Worker            bar: torch.Tensor
2469*da0073e9SAndroid Build Coastguard Worker
2470*da0073e9SAndroid Build Coastguard Worker        class ModuleReturnDataclass(torch.nn.Module):
2471*da0073e9SAndroid Build Coastguard Worker            def forward(self, d : torch.Tensor):
2472*da0073e9SAndroid Build Coastguard Worker                return MyOutput(foo=d + d, bar=d * 3)
2473*da0073e9SAndroid Build Coastguard Worker
2474*da0073e9SAndroid Build Coastguard Worker        class CallsModule(torch.nn.Module):
2475*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2476*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2477*da0073e9SAndroid Build Coastguard Worker                self.m = ModuleReturnDataclass()
2478*da0073e9SAndroid Build Coastguard Worker
2479*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2480*da0073e9SAndroid Build Coastguard Worker                tmp = self.m(x)
2481*da0073e9SAndroid Build Coastguard Worker                return MyOutput(foo=tmp.foo, bar=tmp.bar)
2482*da0073e9SAndroid Build Coastguard Worker
2483*da0073e9SAndroid Build Coastguard Worker        module = CallsModule()
2484*da0073e9SAndroid Build Coastguard Worker        traced_graph = symbolic_trace(module).graph
2485*da0073e9SAndroid Build Coastguard Worker        print(traced_graph)
2486*da0073e9SAndroid Build Coastguard Worker
2487*da0073e9SAndroid Build Coastguard Worker        gm = GraphModule(module, traced_graph)
2488*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(1)
2489*da0073e9SAndroid Build Coastguard Worker
2490*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module(x), gm(x))
2491*da0073e9SAndroid Build Coastguard Worker
2492*da0073e9SAndroid Build Coastguard Worker    def test_trace_return_namedtuple(self):
2493*da0073e9SAndroid Build Coastguard Worker        """
2494*da0073e9SAndroid Build Coastguard Worker        Test case for Module that return namedtuple
2495*da0073e9SAndroid Build Coastguard Worker        """
2496*da0073e9SAndroid Build Coastguard Worker        class MyOutput(NamedTuple):
2497*da0073e9SAndroid Build Coastguard Worker            foo: torch.Tensor
2498*da0073e9SAndroid Build Coastguard Worker            bar: torch.Tensor
2499*da0073e9SAndroid Build Coastguard Worker
2500*da0073e9SAndroid Build Coastguard Worker        class ModuleReturnNamedTuple(torch.nn.Module):
2501*da0073e9SAndroid Build Coastguard Worker            def forward(self, d : torch.Tensor):
2502*da0073e9SAndroid Build Coastguard Worker                return MyOutput(foo=d, bar=d)
2503*da0073e9SAndroid Build Coastguard Worker
2504*da0073e9SAndroid Build Coastguard Worker        module = ModuleReturnNamedTuple()
2505*da0073e9SAndroid Build Coastguard Worker
2506*da0073e9SAndroid Build Coastguard Worker        traced_graph = symbolic_trace(module).graph
2507*da0073e9SAndroid Build Coastguard Worker        print(traced_graph)
2508*da0073e9SAndroid Build Coastguard Worker
2509*da0073e9SAndroid Build Coastguard Worker        gm = GraphModule(module, traced_graph)
2510*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(1)
2511*da0073e9SAndroid Build Coastguard Worker
2512*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module(x), gm(x))
2513*da0073e9SAndroid Build Coastguard Worker
2514*da0073e9SAndroid Build Coastguard Worker    def test_trace_dict_int_keys(self):
2515*da0073e9SAndroid Build Coastguard Worker        class ModWithDictArg(torch.nn.Module):
2516*da0073e9SAndroid Build Coastguard Worker            def forward(self, d : Dict[int, torch.Tensor]):
2517*da0073e9SAndroid Build Coastguard Worker                return d[42]
2518*da0073e9SAndroid Build Coastguard Worker
2519*da0073e9SAndroid Build Coastguard Worker        class CallsModWithDict(torch.nn.Module):
2520*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2521*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2522*da0073e9SAndroid Build Coastguard Worker                self.m = ModWithDictArg()
2523*da0073e9SAndroid Build Coastguard Worker
2524*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2525*da0073e9SAndroid Build Coastguard Worker                return self.m({42: x})
2526*da0073e9SAndroid Build Coastguard Worker
2527*da0073e9SAndroid Build Coastguard Worker        class MyTracer(torch.fx.Tracer):
2528*da0073e9SAndroid Build Coastguard Worker            def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
2529*da0073e9SAndroid Build Coastguard Worker                return isinstance(m, ModWithDictArg)
2530*da0073e9SAndroid Build Coastguard Worker
2531*da0073e9SAndroid Build Coastguard Worker        traced_graph = MyTracer().trace(CallsModWithDict())
2532*da0073e9SAndroid Build Coastguard Worker
2533*da0073e9SAndroid Build Coastguard Worker    def test_trace_dict_proxy_keys(self):
2534*da0073e9SAndroid Build Coastguard Worker        class ModWithDictArg(torch.nn.Module):
2535*da0073e9SAndroid Build Coastguard Worker            def forward(self, d : Dict[torch.Tensor, torch.Tensor]):
2536*da0073e9SAndroid Build Coastguard Worker                return d[42]
2537*da0073e9SAndroid Build Coastguard Worker
2538*da0073e9SAndroid Build Coastguard Worker        class CallsModWithDict(torch.nn.Module):
2539*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2540*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2541*da0073e9SAndroid Build Coastguard Worker                self.m = ModWithDictArg()
2542*da0073e9SAndroid Build Coastguard Worker
2543*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2544*da0073e9SAndroid Build Coastguard Worker                return self.m({x: x})
2545*da0073e9SAndroid Build Coastguard Worker
2546*da0073e9SAndroid Build Coastguard Worker        class MyTracer(torch.fx.Tracer):
2547*da0073e9SAndroid Build Coastguard Worker            def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
2548*da0073e9SAndroid Build Coastguard Worker                return isinstance(m, ModWithDictArg)
2549*da0073e9SAndroid Build Coastguard Worker
2550*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'cannot contain a Node'):
2551*da0073e9SAndroid Build Coastguard Worker            traced_graph = MyTracer().trace(CallsModWithDict())
2552*da0073e9SAndroid Build Coastguard Worker
2553*da0073e9SAndroid Build Coastguard Worker    def test_module_deepcopy_edit_nodes(self):
2554*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
2555*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2556*da0073e9SAndroid Build Coastguard Worker                return torch.relu(x)
2557*da0073e9SAndroid Build Coastguard Worker
2558*da0073e9SAndroid Build Coastguard Worker        traced1 = symbolic_trace(Foo())
2559*da0073e9SAndroid Build Coastguard Worker        copied = copy.deepcopy(traced1)
2560*da0073e9SAndroid Build Coastguard Worker
2561*da0073e9SAndroid Build Coastguard Worker        for node in copied.graph.nodes:
2562*da0073e9SAndroid Build Coastguard Worker            if node.target == torch.relu:
2563*da0073e9SAndroid Build Coastguard Worker                node.target = torch.neg
2564*da0073e9SAndroid Build Coastguard Worker
2565*da0073e9SAndroid Build Coastguard Worker        copied.recompile()
2566*da0073e9SAndroid Build Coastguard Worker        traced1.recompile()
2567*da0073e9SAndroid Build Coastguard Worker
2568*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(15, 15)
2569*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(traced1(x), torch.relu(x))
2570*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(copied(x), torch.neg(x))
2571*da0073e9SAndroid Build Coastguard Worker
2572*da0073e9SAndroid Build Coastguard Worker    def test_direct_param_use(self):
2573*da0073e9SAndroid Build Coastguard Worker        class TransposeTest(torch.nn.Module):
2574*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2575*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2576*da0073e9SAndroid Build Coastguard Worker                self.b = torch.nn.Parameter(torch.rand(4, 3))
2577*da0073e9SAndroid Build Coastguard Worker
2578*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2579*da0073e9SAndroid Build Coastguard Worker                return self.b
2580*da0073e9SAndroid Build Coastguard Worker
2581*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
2582*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2583*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2584*da0073e9SAndroid Build Coastguard Worker                self.a = TransposeTest()
2585*da0073e9SAndroid Build Coastguard Worker
2586*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2587*da0073e9SAndroid Build Coastguard Worker                return self.a.b, self.a.b.t(), self.a.b.view(12)
2588*da0073e9SAndroid Build Coastguard Worker
2589*da0073e9SAndroid Build Coastguard Worker        traced = torch.fx.symbolic_trace(Foo())
2590*da0073e9SAndroid Build Coastguard Worker        assert all('constant' not in node.target for node in traced.graph.nodes)
2591*da0073e9SAndroid Build Coastguard Worker
2592*da0073e9SAndroid Build Coastguard Worker    def test_single_default_arg(self):
2593*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
2594*da0073e9SAndroid Build Coastguard Worker            def forward(self, y=1):
2595*da0073e9SAndroid Build Coastguard Worker                return y
2596*da0073e9SAndroid Build Coastguard Worker
2597*da0073e9SAndroid Build Coastguard Worker        m = M()
2598*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(m, ())
2599*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(m, (3,))
2600*da0073e9SAndroid Build Coastguard Worker
2601*da0073e9SAndroid Build Coastguard Worker    def test_multiple_default_args(self):
2602*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
2603*da0073e9SAndroid Build Coastguard Worker            def forward(self, y=1, z=2):
2604*da0073e9SAndroid Build Coastguard Worker                return y + z
2605*da0073e9SAndroid Build Coastguard Worker
2606*da0073e9SAndroid Build Coastguard Worker        m = M()
2607*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(m, ())
2608*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(m, (3,))
2609*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(m, (3, 4))
2610*da0073e9SAndroid Build Coastguard Worker
2611*da0073e9SAndroid Build Coastguard Worker    def test_regular_and_default_args(self):
2612*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
2613*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y=1):
2614*da0073e9SAndroid Build Coastguard Worker                return x + y
2615*da0073e9SAndroid Build Coastguard Worker
2616*da0073e9SAndroid Build Coastguard Worker        m = M()
2617*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(m, (2,))
2618*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(m, (2, 3))
2619*da0073e9SAndroid Build Coastguard Worker
2620*da0073e9SAndroid Build Coastguard Worker    def test_string_literal_return(self):
2621*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
2622*da0073e9SAndroid Build Coastguard Worker            def forward(self):
2623*da0073e9SAndroid Build Coastguard Worker                return "foo"
2624*da0073e9SAndroid Build Coastguard Worker
2625*da0073e9SAndroid Build Coastguard Worker        m = M()
2626*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(m, ())
2627*da0073e9SAndroid Build Coastguard Worker
2628*da0073e9SAndroid Build Coastguard Worker    def test_namedtuple_return_qualname(self):
2629*da0073e9SAndroid Build Coastguard Worker        class NamedTupReturn(torch.nn.Module):
2630*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2631*da0073e9SAndroid Build Coastguard Worker                return MyNamedTup(x, x)
2632*da0073e9SAndroid Build Coastguard Worker
2633*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(NamedTupReturn())
2634*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(3, 4)
2635*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced(input), MyNamedTup(input, input))
2636*da0073e9SAndroid Build Coastguard Worker
2637*da0073e9SAndroid Build Coastguard Worker    def test_update_args_kwargs_yells_at_you(self):
2638*da0073e9SAndroid Build Coastguard Worker        symtraced = symbolic_trace(SimpleTest())
2639*da0073e9SAndroid Build Coastguard Worker        node = next(iter(symtraced.graph.nodes))
2640*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(AttributeError, '__update_args_kwargs'):
2641*da0073e9SAndroid Build Coastguard Worker            node.__update_args_kwargs((), {})
2642*da0073e9SAndroid Build Coastguard Worker
2643*da0073e9SAndroid Build Coastguard Worker    def test_torchbind_class_attribute_in_fx(self):
2644*da0073e9SAndroid Build Coastguard Worker        if IS_FBCODE or IS_WINDOWS or IS_MACOS:
2645*da0073e9SAndroid Build Coastguard Worker            self.skipTest("torch.classes._TorchScriptTesting._StackString is registered, skipping")
2646*da0073e9SAndroid Build Coastguard Worker
2647*da0073e9SAndroid Build Coastguard Worker        class FooBar1234(torch.nn.Module):
2648*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2649*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2650*da0073e9SAndroid Build Coastguard Worker                self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])
2651*da0073e9SAndroid Build Coastguard Worker
2652*da0073e9SAndroid Build Coastguard Worker            def forward(self):
2653*da0073e9SAndroid Build Coastguard Worker                return self.f.top()
2654*da0073e9SAndroid Build Coastguard Worker
2655*da0073e9SAndroid Build Coastguard Worker        m = FooBar1234()
2656*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(m, ())
2657*da0073e9SAndroid Build Coastguard Worker
2658*da0073e9SAndroid Build Coastguard Worker    def test_torchbind_class_attribute_in_fx_tensor_arg(self):
2659*da0073e9SAndroid Build Coastguard Worker        if IS_FBCODE or IS_WINDOWS or IS_MACOS:
2660*da0073e9SAndroid Build Coastguard Worker            self.skipTest("torch.classes._TorchScriptTesting._ReLUClass is registered, skipping")
2661*da0073e9SAndroid Build Coastguard Worker
2662*da0073e9SAndroid Build Coastguard Worker        class FooBar2341(torch.nn.Module):
2663*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2664*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2665*da0073e9SAndroid Build Coastguard Worker                self.f = torch.classes._TorchScriptTesting._ReLUClass()
2666*da0073e9SAndroid Build Coastguard Worker
2667*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2668*da0073e9SAndroid Build Coastguard Worker                return self.f.run(x)
2669*da0073e9SAndroid Build Coastguard Worker
2670*da0073e9SAndroid Build Coastguard Worker        m = FooBar2341()
2671*da0073e9SAndroid Build Coastguard Worker
2672*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(m)
2673*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, 4)
2674*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced(input), m(input))
2675*da0073e9SAndroid Build Coastguard Worker
2676*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes))
2677*da0073e9SAndroid Build Coastguard Worker
2678*da0073e9SAndroid Build Coastguard Worker    def test_script_method_trace(self):
2679*da0073e9SAndroid Build Coastguard Worker        class Scripted(torch.nn.Module):
2680*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2681*da0073e9SAndroid Build Coastguard Worker                return torch.relu(x)
2682*da0073e9SAndroid Build Coastguard Worker
2683*da0073e9SAndroid Build Coastguard Worker        class Holder(torch.nn.Module):
2684*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2685*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2686*da0073e9SAndroid Build Coastguard Worker                self.s = torch.jit.script(Scripted())
2687*da0073e9SAndroid Build Coastguard Worker
2688*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2689*da0073e9SAndroid Build Coastguard Worker                return self.s(x)
2690*da0073e9SAndroid Build Coastguard Worker
2691*da0073e9SAndroid Build Coastguard Worker        h = Holder()
2692*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(h)
2693*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, 4)
2694*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced(input), h(input))
2695*da0073e9SAndroid Build Coastguard Worker
2696*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes))
2697*da0073e9SAndroid Build Coastguard Worker
2698*da0073e9SAndroid Build Coastguard Worker    def test_namedtuple_return_trace(self):
2699*da0073e9SAndroid Build Coastguard Worker        class NamedTupReturn(torch.nn.Module):
2700*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2701*da0073e9SAndroid Build Coastguard Worker                return Pair(x, x)
2702*da0073e9SAndroid Build Coastguard Worker
2703*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(NamedTupReturn())
2704*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(3, 4)
2705*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced(input), Pair(input, input))
2706*da0073e9SAndroid Build Coastguard Worker
2707*da0073e9SAndroid Build Coastguard Worker    def test_named_tuple_inlined(self):
2708*da0073e9SAndroid Build Coastguard Worker        class NamedTupMod(torch.nn.Module):
2709*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp):
2710*da0073e9SAndroid Build Coastguard Worker                return wrapped_named_tup(Pair(inp, 1.2), p2=Pair(3.4, inp))
2711*da0073e9SAndroid Build Coastguard Worker
2712*da0073e9SAndroid Build Coastguard Worker        m = NamedTupMod()
2713*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(3, 4)
2714*da0073e9SAndroid Build Coastguard Worker        ref = m(input)
2715*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(m)
2716*da0073e9SAndroid Build Coastguard Worker
2717*da0073e9SAndroid Build Coastguard Worker        res = traced(input)
2718*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
2719*da0073e9SAndroid Build Coastguard Worker
2720*da0073e9SAndroid Build Coastguard Worker        # Check Pair NamedTuple works when inlined into the function call.
2721*da0073e9SAndroid Build Coastguard Worker        ph = call_func = None
2722*da0073e9SAndroid Build Coastguard Worker        for node in traced.graph.nodes:
2723*da0073e9SAndroid Build Coastguard Worker            if node.op == "placeholder":
2724*da0073e9SAndroid Build Coastguard Worker                ph = node
2725*da0073e9SAndroid Build Coastguard Worker            elif node.op == "call_function" and node.target == wrapped_named_tup:
2726*da0073e9SAndroid Build Coastguard Worker                node.update_arg(0, Pair(ph, 1.2))
2727*da0073e9SAndroid Build Coastguard Worker                node.update_kwarg("p2", Pair(3.4, ph))
2728*da0073e9SAndroid Build Coastguard Worker                call_func = node
2729*da0073e9SAndroid Build Coastguard Worker                break
2730*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(call_func is not None)
2731*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(call_func.args[0], Pair))
2732*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(call_func.kwargs["p2"], Pair))
2733*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(_format_arg(call_func.args[0]), "Pair(x=%inp, y=1.2)")
2734*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(_format_arg(call_func.kwargs["p2"]), "Pair(x=3.4, y=%inp)")
2735*da0073e9SAndroid Build Coastguard Worker
2736*da0073e9SAndroid Build Coastguard Worker        traced.graph.eliminate_dead_code()
2737*da0073e9SAndroid Build Coastguard Worker        traced.recompile()
2738*da0073e9SAndroid Build Coastguard Worker        res = traced(input)
2739*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
2740*da0073e9SAndroid Build Coastguard Worker
2741*da0073e9SAndroid Build Coastguard Worker    def test_return_type_exists(self):
2742*da0073e9SAndroid Build Coastguard Worker        class ReturnTypeModule(torch.nn.Module):
2743*da0073e9SAndroid Build Coastguard Worker            def other(self, x: List[str]) -> List[str]:
2744*da0073e9SAndroid Build Coastguard Worker                return x
2745*da0073e9SAndroid Build Coastguard Worker
2746*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: List[str]) -> List[str]:
2747*da0073e9SAndroid Build Coastguard Worker                return self.other(x)
2748*da0073e9SAndroid Build Coastguard Worker
2749*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(ReturnTypeModule())
2750*da0073e9SAndroid Build Coastguard Worker        self.assertIn("-> typing_List[str]", traced._code)
2751*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.script(traced)
2752*da0073e9SAndroid Build Coastguard Worker        self.assertIn("-> List[str]", scripted.code)
2753*da0073e9SAndroid Build Coastguard Worker
2754*da0073e9SAndroid Build Coastguard Worker    def getitem_inner(self):
2755*da0073e9SAndroid Build Coastguard Worker        class GetItemBase(torch.nn.Module):
2756*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2757*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2758*da0073e9SAndroid Build Coastguard Worker                self.pe = torch.nn.Buffer(torch.randn(8, 8))
2759*da0073e9SAndroid Build Coastguard Worker
2760*da0073e9SAndroid Build Coastguard Worker        class GetItem1(GetItemBase):
2761*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2762*da0073e9SAndroid Build Coastguard Worker                return self.pe[:, :x.size(0)]
2763*da0073e9SAndroid Build Coastguard Worker
2764*da0073e9SAndroid Build Coastguard Worker        class GetItem2(GetItemBase):
2765*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2766*da0073e9SAndroid Build Coastguard Worker                return self.pe[x.size(0)]
2767*da0073e9SAndroid Build Coastguard Worker
2768*da0073e9SAndroid Build Coastguard Worker        class GetItem3(GetItemBase):
2769*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2770*da0073e9SAndroid Build Coastguard Worker                return self.pe[4]  # fx creates `self._tensor_constant0` here
2771*da0073e9SAndroid Build Coastguard Worker
2772*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(GetItem1(), [torch.zeros(4)])
2773*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(GetItem2(), [torch.zeros(4)])
2774*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(GetItem3(), [torch.zeros(4)])
2775*da0073e9SAndroid Build Coastguard Worker
2776*da0073e9SAndroid Build Coastguard Worker    @unittest.skipUnless(os.environ.get("FX_PATCH_GETITEM") == "1",
2777*da0073e9SAndroid Build Coastguard Worker                         "Will be checked in test_getitem_subproc")
2778*da0073e9SAndroid Build Coastguard Worker    def test_getitem(self):
2779*da0073e9SAndroid Build Coastguard Worker        self.getitem_inner()
2780*da0073e9SAndroid Build Coastguard Worker
2781*da0073e9SAndroid Build Coastguard Worker    def test_getitem_subproc(self):
2782*da0073e9SAndroid Build Coastguard Worker        # need to run this test in a subproc to work around:
2783*da0073e9SAndroid Build Coastguard Worker        #   https://github.com/pytorch/pytorch/issues/50710
2784*da0073e9SAndroid Build Coastguard Worker        proc = Process(target=run_getitem_target)
2785*da0073e9SAndroid Build Coastguard Worker        proc.start()
2786*da0073e9SAndroid Build Coastguard Worker        proc.join()
2787*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(proc.exitcode, 0)
2788*da0073e9SAndroid Build Coastguard Worker
2789*da0073e9SAndroid Build Coastguard Worker    def test_user_friendly_call_provenance_with_function(self):
2790*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2791*da0073e9SAndroid Build Coastguard Worker            return wrapper_fn(x)
2792*da0073e9SAndroid Build Coastguard Worker
2793*da0073e9SAndroid Build Coastguard Worker        traced = torch.fx.symbolic_trace(fn)
2794*da0073e9SAndroid Build Coastguard Worker
2795*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is "
2796*da0073e9SAndroid Build Coastguard Worker                                    "being compiled since it was called"
2797*da0073e9SAndroid Build Coastguard Worker                                    " from 'fn.forward'"):
2798*da0073e9SAndroid Build Coastguard Worker            scripted = torch.jit.script(traced)
2799*da0073e9SAndroid Build Coastguard Worker
2800*da0073e9SAndroid Build Coastguard Worker    def test_user_friendly_call_provenance_with_module(self):
2801*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
2802*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2803*da0073e9SAndroid Build Coastguard Worker                return wrapper_fn(x)
2804*da0073e9SAndroid Build Coastguard Worker
2805*da0073e9SAndroid Build Coastguard Worker        traced = torch.fx.symbolic_trace(M())
2806*da0073e9SAndroid Build Coastguard Worker
2807*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is "
2808*da0073e9SAndroid Build Coastguard Worker                                    "being compiled since it was called"
2809*da0073e9SAndroid Build Coastguard Worker                                    " from 'M.forward'"):
2810*da0073e9SAndroid Build Coastguard Worker            scripted = torch.jit.script(traced)
2811*da0073e9SAndroid Build Coastguard Worker
2812*da0073e9SAndroid Build Coastguard Worker    def test_snake_case(self):
2813*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
2814*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2815*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2816*da0073e9SAndroid Build Coastguard Worker                self.activations = torch.nn.ModuleDict([
2817*da0073e9SAndroid Build Coastguard Worker                    ["snake_case", torch.nn.ReLU()],
2818*da0073e9SAndroid Build Coastguard Worker                    ["PascalCase", torch.nn.LeakyReLU()],
2819*da0073e9SAndroid Build Coastguard Worker                    ["ALL_CAPS", torch.nn.PReLU()]
2820*da0073e9SAndroid Build Coastguard Worker                ])
2821*da0073e9SAndroid Build Coastguard Worker
2822*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2823*da0073e9SAndroid Build Coastguard Worker                a = self.activations["snake_case"](x)
2824*da0073e9SAndroid Build Coastguard Worker                b = self.activations["PascalCase"](x)
2825*da0073e9SAndroid Build Coastguard Worker                c = self.activations["ALL_CAPS"](x)
2826*da0073e9SAndroid Build Coastguard Worker                return a, b, c
2827*da0073e9SAndroid Build Coastguard Worker
2828*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace(M())
2829*da0073e9SAndroid Build Coastguard Worker
2830*da0073e9SAndroid Build Coastguard Worker        check = [
2831*da0073e9SAndroid Build Coastguard Worker            ("activations_snake_case", "activations.snake_case"),
2832*da0073e9SAndroid Build Coastguard Worker            ("activations_pascal_case", "activations.PascalCase"),
2833*da0073e9SAndroid Build Coastguard Worker            ("activations_all_caps", "activations.ALL_CAPS")
2834*da0073e9SAndroid Build Coastguard Worker        ]
2835*da0073e9SAndroid Build Coastguard Worker
2836*da0073e9SAndroid Build Coastguard Worker        i = 0
2837*da0073e9SAndroid Build Coastguard Worker        for node in traced.graph.nodes:
2838*da0073e9SAndroid Build Coastguard Worker            if node.op == "placeholder" or node.op == "output":
2839*da0073e9SAndroid Build Coastguard Worker                continue
2840*da0073e9SAndroid Build Coastguard Worker            name = check[i][0]
2841*da0073e9SAndroid Build Coastguard Worker            target = check[i][1]
2842*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(name, node.name)
2843*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(target, node.target)
2844*da0073e9SAndroid Build Coastguard Worker            i += 1
2845*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(i, 3)
2846*da0073e9SAndroid Build Coastguard Worker
2847*da0073e9SAndroid Build Coastguard Worker    def test_no_mutation(self):
2848*da0073e9SAndroid Build Coastguard Worker        from torch.fx.immutable_collections import immutable_list
2849*da0073e9SAndroid Build Coastguard Worker        x = immutable_list([3, 4])
2850*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(NotImplementedError, "new_args"):
2851*da0073e9SAndroid Build Coastguard Worker            x[0] = 4
2852*da0073e9SAndroid Build Coastguard Worker
2853*da0073e9SAndroid Build Coastguard Worker    def test_partial_trace(self):
2854*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
2855*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y):
2856*da0073e9SAndroid Build Coastguard Worker                if y:
2857*da0073e9SAndroid Build Coastguard Worker                    return 2 * x
2858*da0073e9SAndroid Build Coastguard Worker                else:
2859*da0073e9SAndroid Build Coastguard Worker                    return x
2860*da0073e9SAndroid Build Coastguard Worker        mod = Foo()
2861*da0073e9SAndroid Build Coastguard Worker        mod_true = symbolic_trace(mod, concrete_args={'y': True})
2862*da0073e9SAndroid Build Coastguard Worker        mod_false = symbolic_trace(mod, concrete_args={'y': False})
2863*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(mod_true(3, True), 6)
2864*da0073e9SAndroid Build Coastguard Worker        print(mod_true.code)
2865*da0073e9SAndroid Build Coastguard Worker        assert any(i.target == torch._assert for i in mod_true.graph.nodes)
2866*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
2867*da0073e9SAndroid Build Coastguard Worker            mod_true(3, False)
2868*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(mod_false(3, False), 3)
2869*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
2870*da0073e9SAndroid Build Coastguard Worker            mod_false(3, True)
2871*da0073e9SAndroid Build Coastguard Worker
2872*da0073e9SAndroid Build Coastguard Worker        def f_higher(a, f):
2873*da0073e9SAndroid Build Coastguard Worker            return f(a)
2874*da0073e9SAndroid Build Coastguard Worker
2875*da0073e9SAndroid Build Coastguard Worker        nf = symbolic_trace(f_higher, concrete_args={'f': lambda x: x * 2})
2876*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nf(3, lambda x: x * 2), 6)
2877*da0073e9SAndroid Build Coastguard Worker
2878*da0073e9SAndroid Build Coastguard Worker    def test_custom_traceback_raised_when_exception_source_is_graphmodule(self):
2879*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
2880*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2881*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2882*da0073e9SAndroid Build Coastguard Worker                self.W = torch.nn.Parameter(torch.randn(5))
2883*da0073e9SAndroid Build Coastguard Worker
2884*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2885*da0073e9SAndroid Build Coastguard Worker                return torch.dot(self.W, x)
2886*da0073e9SAndroid Build Coastguard Worker
2887*da0073e9SAndroid Build Coastguard Worker        traced = torch.fx.symbolic_trace(M())
2888*da0073e9SAndroid Build Coastguard Worker
2889*da0073e9SAndroid Build Coastguard Worker        out = [n for n in traced.graph.nodes if n.op == "output"][-1]
2890*da0073e9SAndroid Build Coastguard Worker        with traced.graph.inserting_before(out):
2891*da0073e9SAndroid Build Coastguard Worker            relu_out = traced.graph.call_method(method_name='relu',
2892*da0073e9SAndroid Build Coastguard Worker                                                args=(out.args[0],))
2893*da0073e9SAndroid Build Coastguard Worker        out.args = (relu_out,)
2894*da0073e9SAndroid Build Coastguard Worker
2895*da0073e9SAndroid Build Coastguard Worker        traced.recompile()
2896*da0073e9SAndroid Build Coastguard Worker
2897*da0073e9SAndroid Build Coastguard Worker        with self.capture_stderr() as captured:
2898*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(TypeError):
2899*da0073e9SAndroid Build Coastguard Worker                traced(5)
2900*da0073e9SAndroid Build Coastguard Worker
2901*da0073e9SAndroid Build Coastguard Worker        self.assertRegex(captured[0],
2902*da0073e9SAndroid Build Coastguard Worker                         r"Call using an FX-traced Module, line .* of the "
2903*da0073e9SAndroid Build Coastguard Worker                         r"traced Module's generated forward function:")
2904*da0073e9SAndroid Build Coastguard Worker
2905*da0073e9SAndroid Build Coastguard Worker    def test_custom_traceback_not_raised_when_exception_source_is_submodule(self):
2906*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
2907*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2908*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2909*da0073e9SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(3, 4)
2910*da0073e9SAndroid Build Coastguard Worker
2911*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2912*da0073e9SAndroid Build Coastguard Worker                return self.linear(x)
2913*da0073e9SAndroid Build Coastguard Worker
2914*da0073e9SAndroid Build Coastguard Worker        traced = torch.fx.symbolic_trace(M())
2915*da0073e9SAndroid Build Coastguard Worker
2916*da0073e9SAndroid Build Coastguard Worker        # Do not change this to `capture_stderr` or another context
2917*da0073e9SAndroid Build Coastguard Worker        # manager without ensuring that the output is as expected
2918*da0073e9SAndroid Build Coastguard Worker        try:
2919*da0073e9SAndroid Build Coastguard Worker            traced(torch.rand(5, 5))
2920*da0073e9SAndroid Build Coastguard Worker        except RuntimeError:
2921*da0073e9SAndroid Build Coastguard Worker            captured = traceback.format_exc()
2922*da0073e9SAndroid Build Coastguard Worker
2923*da0073e9SAndroid Build Coastguard Worker        self.assertNotRegex(captured,
2924*da0073e9SAndroid Build Coastguard Worker                            r"Call using an FX-traced Module, line .* of the "
2925*da0073e9SAndroid Build Coastguard Worker                            r"traced Module's generated forward function:")
2926*da0073e9SAndroid Build Coastguard Worker
2927*da0073e9SAndroid Build Coastguard Worker    def test_graph_module_replicate_for_dp(self):
2928*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
2929*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2930*da0073e9SAndroid Build Coastguard Worker                return torch.relu(x)
2931*da0073e9SAndroid Build Coastguard Worker
2932*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.symbolic_trace(Foo())
2933*da0073e9SAndroid Build Coastguard Worker
2934*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 3)
2935*da0073e9SAndroid Build Coastguard Worker        out = gm(x)
2936*da0073e9SAndroid Build Coastguard Worker
2937*da0073e9SAndroid Build Coastguard Worker        replica = gm._replicate_for_data_parallel()
2938*da0073e9SAndroid Build Coastguard Worker        out_replica = replica(x)
2939*da0073e9SAndroid Build Coastguard Worker
2940*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(out_replica, out)
2941*da0073e9SAndroid Build Coastguard Worker
2942*da0073e9SAndroid Build Coastguard Worker    def test_ast_rewriter_rewrites_assert(self):
2943*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
2944*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor, y: int, z: int):
2945*da0073e9SAndroid Build Coastguard Worker                assert y == z
2946*da0073e9SAndroid Build Coastguard Worker                return torch.add(x, x)
2947*da0073e9SAndroid Build Coastguard Worker
2948*da0073e9SAndroid Build Coastguard Worker        ast_rewriter = RewritingTracer()
2949*da0073e9SAndroid Build Coastguard Worker        graph = ast_rewriter.trace(M())
2950*da0073e9SAndroid Build Coastguard Worker        traced = GraphModule(ast_rewriter.root, graph, "gm")
2951*da0073e9SAndroid Build Coastguard Worker
2952*da0073e9SAndroid Build Coastguard Worker        traced.graph.lint()
2953*da0073e9SAndroid Build Coastguard Worker
2954*da0073e9SAndroid Build Coastguard Worker    def test_ast_rewriter_rewrites_assert_with_message(self):
2955*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
2956*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor, y: int, z: int):
2957*da0073e9SAndroid Build Coastguard Worker                assert y == z, "msg"
2958*da0073e9SAndroid Build Coastguard Worker                return torch.add(x, x)
2959*da0073e9SAndroid Build Coastguard Worker
2960*da0073e9SAndroid Build Coastguard Worker        ast_rewriter = RewritingTracer()
2961*da0073e9SAndroid Build Coastguard Worker        graph = ast_rewriter.trace(M())
2962*da0073e9SAndroid Build Coastguard Worker        traced = GraphModule(ast_rewriter.root, graph, "gm")
2963*da0073e9SAndroid Build Coastguard Worker
2964*da0073e9SAndroid Build Coastguard Worker        traced.graph.lint()
2965*da0073e9SAndroid Build Coastguard Worker
2966*da0073e9SAndroid Build Coastguard Worker    def test_throw_out_variant(self):
2967*da0073e9SAndroid Build Coastguard Worker        def foo(x):
2968*da0073e9SAndroid Build Coastguard Worker            y = torch.rand_like(x)
2969*da0073e9SAndroid Build Coastguard Worker            torch.sigmoid(x, out=y)
2970*da0073e9SAndroid Build Coastguard Worker            return y
2971*da0073e9SAndroid Build Coastguard Worker
2972*da0073e9SAndroid Build Coastguard Worker        class MyTracer(torch.fx.Tracer):
2973*da0073e9SAndroid Build Coastguard Worker            check_mutable_operations = True
2974*da0073e9SAndroid Build Coastguard Worker
2975*da0073e9SAndroid Build Coastguard Worker        tracer = MyTracer()
2976*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'mutable operation aten::sigmoid.out'):
2977*da0073e9SAndroid Build Coastguard Worker            traced_graph = tracer.trace(foo)
2978*da0073e9SAndroid Build Coastguard Worker
2979*da0073e9SAndroid Build Coastguard Worker    def test_ast_rewriter_reassigns_submodules(self):
2980*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
2981*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2982*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2983*da0073e9SAndroid Build Coastguard Worker                self.bn = torch.nn.BatchNorm2d(100)
2984*da0073e9SAndroid Build Coastguard Worker
2985*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor):
2986*da0073e9SAndroid Build Coastguard Worker                return torch.add(x, x)
2987*da0073e9SAndroid Build Coastguard Worker
2988*da0073e9SAndroid Build Coastguard Worker        ast_rewriter = RewritingTracer()
2989*da0073e9SAndroid Build Coastguard Worker        graph = ast_rewriter.trace(M())
2990*da0073e9SAndroid Build Coastguard Worker        traced = GraphModule(ast_rewriter.root, graph, "gm")
2991*da0073e9SAndroid Build Coastguard Worker
2992*da0073e9SAndroid Build Coastguard Worker        traced.graph.lint()
2993*da0073e9SAndroid Build Coastguard Worker
2994*da0073e9SAndroid Build Coastguard Worker    def test_ast_rewriter_wrap(self):
2995*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5))
2996*da0073e9SAndroid Build Coastguard Worker
2997*da0073e9SAndroid Build Coastguard Worker        def to_trace(y):
2998*da0073e9SAndroid Build Coastguard Worker            return (
2999*da0073e9SAndroid Build Coastguard Worker                a_lifted_leaf((4, y), 3)
3000*da0073e9SAndroid Build Coastguard Worker                + a_lifted_leaf((3, 4), 5)
3001*da0073e9SAndroid Build Coastguard Worker                + a_lifted_leaf((y, y), y)
3002*da0073e9SAndroid Build Coastguard Worker            )
3003*da0073e9SAndroid Build Coastguard Worker
3004*da0073e9SAndroid Build Coastguard Worker        ast_rewriter = RewritingTracer()
3005*da0073e9SAndroid Build Coastguard Worker        graph = ast_rewriter.trace(to_trace)
3006*da0073e9SAndroid Build Coastguard Worker        traced = GraphModule(ast_rewriter.root, graph, "gm")
3007*da0073e9SAndroid Build Coastguard Worker
3008*da0073e9SAndroid Build Coastguard Worker        self.assertIn("a_lifted_leaf", traced.code)
3009*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(27, traced(2))
3010*da0073e9SAndroid Build Coastguard Worker        self.assertIs(a_lifted_leaf, real_a_lifed_leaf)
3011*da0073e9SAndroid Build Coastguard Worker
3012*da0073e9SAndroid Build Coastguard Worker    def test_ast_rewriter_wrap_fn_directly(self):
3013*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))
3014*da0073e9SAndroid Build Coastguard Worker
3015*da0073e9SAndroid Build Coastguard Worker        def to_trace(y):
3016*da0073e9SAndroid Build Coastguard Worker            return (
3017*da0073e9SAndroid Build Coastguard Worker                a_lifted_leaf2((4, y), 3)
3018*da0073e9SAndroid Build Coastguard Worker                + a_lifted_leaf2((3, 4), 5)
3019*da0073e9SAndroid Build Coastguard Worker                + a_lifted_leaf2((y, y), y)
3020*da0073e9SAndroid Build Coastguard Worker            )
3021*da0073e9SAndroid Build Coastguard Worker
3022*da0073e9SAndroid Build Coastguard Worker        ast_rewriter = RewritingTracer()
3023*da0073e9SAndroid Build Coastguard Worker        graph = ast_rewriter.trace(to_trace)
3024*da0073e9SAndroid Build Coastguard Worker        traced = GraphModule(ast_rewriter.root, graph, "gm")
3025*da0073e9SAndroid Build Coastguard Worker
3026*da0073e9SAndroid Build Coastguard Worker        self.assertIn("a_lifted_leaf2", traced.code)
3027*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(27, traced(2))
3028*da0073e9SAndroid Build Coastguard Worker        self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2)
3029*da0073e9SAndroid Build Coastguard Worker
3030*da0073e9SAndroid Build Coastguard Worker    def test_profiler_ranges_side_effect(self):
3031*da0073e9SAndroid Build Coastguard Worker        g = torch.fx.Graph()
3032*da0073e9SAndroid Build Coastguard Worker        handle = g.call_function(torch.ops.profiler._record_function_enter_new, ('test_range',))
3033*da0073e9SAndroid Build Coastguard Worker        g.call_function(torch.ops.profiler._record_function_exit, (handle,))
3034*da0073e9SAndroid Build Coastguard Worker        g.output(None)
3035*da0073e9SAndroid Build Coastguard Worker
3036*da0073e9SAndroid Build Coastguard Worker        found_targets = {}
3037*da0073e9SAndroid Build Coastguard Worker        for node in g.nodes:
3038*da0073e9SAndroid Build Coastguard Worker            if node.op == 'call_function':
3039*da0073e9SAndroid Build Coastguard Worker                found_targets.setdefault(node.target)
3040*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3041*da0073e9SAndroid Build Coastguard Worker            list(found_targets.keys()),
3042*da0073e9SAndroid Build Coastguard Worker            [torch.ops.profiler._record_function_enter_new, torch.ops.profiler._record_function_exit]
3043*da0073e9SAndroid Build Coastguard Worker        )
3044*da0073e9SAndroid Build Coastguard Worker
3045*da0073e9SAndroid Build Coastguard Worker        g.eliminate_dead_code()
3046*da0073e9SAndroid Build Coastguard Worker        found_targets = {}
3047*da0073e9SAndroid Build Coastguard Worker        for node in g.nodes:
3048*da0073e9SAndroid Build Coastguard Worker            if node.op == 'call_function':
3049*da0073e9SAndroid Build Coastguard Worker                found_targets.setdefault(node.target)
3050*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3051*da0073e9SAndroid Build Coastguard Worker            list(found_targets.keys()),
3052*da0073e9SAndroid Build Coastguard Worker            [torch.ops.profiler._record_function_enter_new, torch.ops.profiler._record_function_exit]
3053*da0073e9SAndroid Build Coastguard Worker        )
3054*da0073e9SAndroid Build Coastguard Worker
3055*da0073e9SAndroid Build Coastguard Worker    def test_ast_rewriter_wrapped_via_decorator(self):
3056*da0073e9SAndroid Build Coastguard Worker        class F(torch.nn.Module):
3057*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3058*da0073e9SAndroid Build Coastguard Worker                return wrapped_via_decorator(x)
3059*da0073e9SAndroid Build Coastguard Worker
3060*da0073e9SAndroid Build Coastguard Worker        ast_rewriter = RewritingTracer()
3061*da0073e9SAndroid Build Coastguard Worker        graph = ast_rewriter.trace(F())
3062*da0073e9SAndroid Build Coastguard Worker        traced = GraphModule(ast_rewriter.root, graph, "gm")
3063*da0073e9SAndroid Build Coastguard Worker
3064*da0073e9SAndroid Build Coastguard Worker        self.assertIn("wrapped_via_decorator", traced.code)
3065*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced(0), 1)
3066*da0073e9SAndroid Build Coastguard Worker        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
3067*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
3068*da0073e9SAndroid Build Coastguard Worker
3069*da0073e9SAndroid Build Coastguard Worker    def test_ast_rewriter_wrapped_via_decorator_and_transformed(self):
3070*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(wrapped_via_decorator(0), 1)
3071*da0073e9SAndroid Build Coastguard Worker
3072*da0073e9SAndroid Build Coastguard Worker        def to_trace(y):
3073*da0073e9SAndroid Build Coastguard Worker            return wrapped_via_decorator(y)
3074*da0073e9SAndroid Build Coastguard Worker
3075*da0073e9SAndroid Build Coastguard Worker        ast_rewriter = RewritingTracer()
3076*da0073e9SAndroid Build Coastguard Worker        graph = ast_rewriter.trace(to_trace)
3077*da0073e9SAndroid Build Coastguard Worker        traced = GraphModule(ast_rewriter.root, graph, "gm")
3078*da0073e9SAndroid Build Coastguard Worker
3079*da0073e9SAndroid Build Coastguard Worker        self.assertIn("wrapped_via_decorator", traced.code)
3080*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced(0), 1)
3081*da0073e9SAndroid Build Coastguard Worker        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
3082*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
3083*da0073e9SAndroid Build Coastguard Worker
3084*da0073e9SAndroid Build Coastguard Worker        transformed = torch.fx.Transformer(traced).transform()
3085*da0073e9SAndroid Build Coastguard Worker        self.assertIn("wrapped_via_decorator", transformed.code)
3086*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(transformed(0), 1)
3087*da0073e9SAndroid Build Coastguard Worker        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
3088*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
3089*da0073e9SAndroid Build Coastguard Worker
3090*da0073e9SAndroid Build Coastguard Worker    def test_ast_rewriter_wrap_with_submodule(self):
3091*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
3092*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
3093*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3094*da0073e9SAndroid Build Coastguard Worker                self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
3095*da0073e9SAndroid Build Coastguard Worker
3096*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor):
3097*da0073e9SAndroid Build Coastguard Worker                return wrapped_with_submodule(x, self.batchnorm1d)
3098*da0073e9SAndroid Build Coastguard Worker
3099*da0073e9SAndroid Build Coastguard Worker        ast_rewriter = RewritingTracer()
3100*da0073e9SAndroid Build Coastguard Worker        graph = ast_rewriter.trace(M())
3101*da0073e9SAndroid Build Coastguard Worker        traced = GraphModule(ast_rewriter.root, graph, "gm")
3102*da0073e9SAndroid Build Coastguard Worker
3103*da0073e9SAndroid Build Coastguard Worker        self.assertIn("wrapped_with_submodule", traced.code)
3104*da0073e9SAndroid Build Coastguard Worker
3105*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(3, 2)
3106*da0073e9SAndroid Build Coastguard Worker        ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
3107*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref_batchnorm1d(input), traced(input))
3108*da0073e9SAndroid Build Coastguard Worker
3109*da0073e9SAndroid Build Coastguard Worker    def test_submodule_manipulation_API(self):
3110*da0073e9SAndroid Build Coastguard Worker        class C(torch.nn.Module):
3111*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
3112*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3113*da0073e9SAndroid Build Coastguard Worker                self.conv = torch.nn.Conv2d(16, 33, 3, stride=2)
3114*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(2, 3))
3115*da0073e9SAndroid Build Coastguard Worker
3116*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3117*da0073e9SAndroid Build Coastguard Worker                return self.conv(torch.cat([self.param, x]))
3118*da0073e9SAndroid Build Coastguard Worker
3119*da0073e9SAndroid Build Coastguard Worker        class B(torch.nn.Module):
3120*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
3121*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3122*da0073e9SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(100, 200)
3123*da0073e9SAndroid Build Coastguard Worker                self.buf = torch.nn.Buffer(torch.randn(2, 3))
3124*da0073e9SAndroid Build Coastguard Worker                self.net_c = C()
3125*da0073e9SAndroid Build Coastguard Worker
3126*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3127*da0073e9SAndroid Build Coastguard Worker                return self.linear(torch.cat([self.buf, self.net_c(x)]))
3128*da0073e9SAndroid Build Coastguard Worker
3129*da0073e9SAndroid Build Coastguard Worker        class A(torch.nn.Module):
3130*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
3131*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3132*da0073e9SAndroid Build Coastguard Worker                self.net_b = B()
3133*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(2, 3))
3134*da0073e9SAndroid Build Coastguard Worker
3135*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3136*da0073e9SAndroid Build Coastguard Worker                return self.net_b(x) + self.param
3137*da0073e9SAndroid Build Coastguard Worker
3138*da0073e9SAndroid Build Coastguard Worker        a = symbolic_trace(A())
3139*da0073e9SAndroid Build Coastguard Worker
3140*da0073e9SAndroid Build Coastguard Worker        a.add_submodule("net_b.net_c.dropout", torch.nn.Dropout(p=0.2))
3141*da0073e9SAndroid Build Coastguard Worker
3142*da0073e9SAndroid Build Coastguard Worker        conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"][-1]
3143*da0073e9SAndroid Build Coastguard Worker        with a.graph.inserting_before(conv):
3144*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as w:
3145*da0073e9SAndroid Build Coastguard Worker                dropout = a.graph.call_module(module_name="net_b.net_c.dropout",
3146*da0073e9SAndroid Build Coastguard Worker                                              args=conv.args)
3147*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(len(w), 0)
3148*da0073e9SAndroid Build Coastguard Worker
3149*da0073e9SAndroid Build Coastguard Worker        conv.replace_all_uses_with(dropout)
3150*da0073e9SAndroid Build Coastguard Worker        a.graph.erase_node(conv)
3151*da0073e9SAndroid Build Coastguard Worker        a.recompile()
3152*da0073e9SAndroid Build Coastguard Worker
3153*da0073e9SAndroid Build Coastguard Worker        def module_exists(gm: GraphModule, path: str) -> bool:
3154*da0073e9SAndroid Build Coastguard Worker            return any(path == name for name, _ in gm.named_modules())
3155*da0073e9SAndroid Build Coastguard Worker
3156*da0073e9SAndroid Build Coastguard Worker        def parameter_exists(gm: GraphModule, path: str) -> bool:
3157*da0073e9SAndroid Build Coastguard Worker            return (any(path == name for name, _ in gm.named_parameters())
3158*da0073e9SAndroid Build Coastguard Worker                    and any(path == name for name in gm.state_dict().keys()))
3159*da0073e9SAndroid Build Coastguard Worker
3160*da0073e9SAndroid Build Coastguard Worker        def buffer_exists(gm: GraphModule, path: str) -> bool:
3161*da0073e9SAndroid Build Coastguard Worker            return (any(path == name for name, _ in gm.named_buffers())
3162*da0073e9SAndroid Build Coastguard Worker                    and any(path == name for name in gm.state_dict().keys()))
3163*da0073e9SAndroid Build Coastguard Worker
3164*da0073e9SAndroid Build Coastguard Worker        # Test that we added the "dropout" submodule
3165*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(module_exists(a, "net_b.net_c.dropout"))
3166*da0073e9SAndroid Build Coastguard Worker
3167*da0073e9SAndroid Build Coastguard Worker        # Test `get_submodule` with an added submodule
3168*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(a.get_submodule("net_b.net_c.dropout"))
3169*da0073e9SAndroid Build Coastguard Worker
3170*da0073e9SAndroid Build Coastguard Worker        # Test that the "conv" submodule is still there
3171*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(module_exists(a, "net_b.net_c.conv"))
3172*da0073e9SAndroid Build Coastguard Worker
3173*da0073e9SAndroid Build Coastguard Worker        # Test `get_submodule` with an original module
3174*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(a.get_submodule("net_b.net_c.conv"))
3175*da0073e9SAndroid Build Coastguard Worker
3176*da0073e9SAndroid Build Coastguard Worker        # Test that the "conv" node is NOT still there
3177*da0073e9SAndroid Build Coastguard Worker        conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"]
3178*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(conv, [])
3179*da0073e9SAndroid Build Coastguard Worker
3180*da0073e9SAndroid Build Coastguard Worker        a.delete_submodule("net_b.net_c.conv")
3181*da0073e9SAndroid Build Coastguard Worker
3182*da0073e9SAndroid Build Coastguard Worker        # Test that the "conv" submodule is now gone
3183*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(module_exists(a, "net_b.net_c.conv"))
3184*da0073e9SAndroid Build Coastguard Worker
3185*da0073e9SAndroid Build Coastguard Worker        # Test `get_submodule` with a deleted submodule
3186*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(AttributeError, "has no attribute "
3187*da0073e9SAndroid Build Coastguard Worker                                    "`conv`"):
3188*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(a.get_submodule("net_b.net_c.conv"))
3189*da0073e9SAndroid Build Coastguard Worker
3190*da0073e9SAndroid Build Coastguard Worker        # Test `get_attr` warnings
3191*da0073e9SAndroid Build Coastguard Worker        cat = [n for n in a.graph.nodes if n.target == torch.cat][-1]
3192*da0073e9SAndroid Build Coastguard Worker
3193*da0073e9SAndroid Build Coastguard Worker        with a.graph.inserting_before(cat):
3194*da0073e9SAndroid Build Coastguard Worker
3195*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as w:
3196*da0073e9SAndroid Build Coastguard Worker                param = a.graph.get_attr(qualified_name="net_b.net_c.param")
3197*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(len(w), 0)
3198*da0073e9SAndroid Build Coastguard Worker
3199*da0073e9SAndroid Build Coastguard Worker            with self.assertWarnsRegex(UserWarning, "Attempted to "
3200*da0073e9SAndroid Build Coastguard Worker                                       "insert a get_attr Node with no "
3201*da0073e9SAndroid Build Coastguard Worker                                       "underlying reference in the "
3202*da0073e9SAndroid Build Coastguard Worker                                       "owning GraphModule"):
3203*da0073e9SAndroid Build Coastguard Worker                bad_param = a.graph.get_attr(qualified_name="net_b.param")
3204*da0073e9SAndroid Build Coastguard Worker                a.graph.erase_node(bad_param)
3205*da0073e9SAndroid Build Coastguard Worker
3206*da0073e9SAndroid Build Coastguard Worker        cat.args = (*cat.args, param)
3207*da0073e9SAndroid Build Coastguard Worker
3208*da0073e9SAndroid Build Coastguard Worker        a.recompile()
3209*da0073e9SAndroid Build Coastguard Worker
3210*da0073e9SAndroid Build Coastguard Worker        a.graph.lint()
3211*da0073e9SAndroid Build Coastguard Worker
3212*da0073e9SAndroid Build Coastguard Worker        # Test `get_parameter`
3213*da0073e9SAndroid Build Coastguard Worker        a.get_parameter("net_b.net_c.param")
3214*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(AttributeError, "is not an "
3215*da0073e9SAndroid Build Coastguard Worker                                    "nn.Parameter"):
3216*da0073e9SAndroid Build Coastguard Worker            a.get_parameter("net_b.buf")
3217*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(AttributeError, "has no attribute "
3218*da0073e9SAndroid Build Coastguard Worker                                    "`param`"):
3219*da0073e9SAndroid Build Coastguard Worker            a.get_parameter("net_b.param")
3220*da0073e9SAndroid Build Coastguard Worker
3221*da0073e9SAndroid Build Coastguard Worker        # Test `get_buffer`
3222*da0073e9SAndroid Build Coastguard Worker        a.get_buffer("net_b.buf")
3223*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(AttributeError, "is not a "
3224*da0073e9SAndroid Build Coastguard Worker                                    "buffer"):
3225*da0073e9SAndroid Build Coastguard Worker            a.get_buffer("net_b.net_c.param")
3226*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(AttributeError, "has no attribute "
3227*da0073e9SAndroid Build Coastguard Worker                                    "`buf`"):
3228*da0073e9SAndroid Build Coastguard Worker            a.get_buffer("net_b.net_c.buf")
3229*da0073e9SAndroid Build Coastguard Worker
3230*da0073e9SAndroid Build Coastguard Worker        # Test non-nested attributes
3231*da0073e9SAndroid Build Coastguard Worker        a.get_submodule("")
3232*da0073e9SAndroid Build Coastguard Worker        a.get_parameter("param")
3233*da0073e9SAndroid Build Coastguard Worker
3234*da0073e9SAndroid Build Coastguard Worker        # Insert some unused submodules
3235*da0073e9SAndroid Build Coastguard Worker        a.add_submodule("net_b.embedding", torch.nn.Embedding(10, 3))
3236*da0073e9SAndroid Build Coastguard Worker        a.add_submodule("net_b.net_c.embedding", torch.nn.Embedding(10, 3))
3237*da0073e9SAndroid Build Coastguard Worker        a.add_submodule("net_b.net_c.rnn", torch.nn.RNN(10, 20, 2))
3238*da0073e9SAndroid Build Coastguard Worker        a.add_submodule("batch_norm_2d", torch.nn.BatchNorm2d(100))
3239*da0073e9SAndroid Build Coastguard Worker
3240*da0073e9SAndroid Build Coastguard Worker        # Garbage collection
3241*da0073e9SAndroid Build Coastguard Worker        a.delete_all_unused_submodules()
3242*da0073e9SAndroid Build Coastguard Worker
3243*da0073e9SAndroid Build Coastguard Worker        # Test that all the unused submodules are gone
3244*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(module_exists(a, "net_b.embedding"))
3245*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(module_exists(a, "net_b.net_c.embedding"))
3246*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(module_exists(a, "net_b.net_c.rnn"))
3247*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(module_exists(a, "batch_norm_2d"))
3248*da0073e9SAndroid Build Coastguard Worker
3249*da0073e9SAndroid Build Coastguard Worker        # Test that we didn't delete any unused Parameters or buffers
3250*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(parameter_exists(a, "net_b.net_c.param"))
3251*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(buffer_exists(a, "net_b.buf"))
3252*da0073e9SAndroid Build Coastguard Worker
3253*da0073e9SAndroid Build Coastguard Worker        a.graph.lint()
3254*da0073e9SAndroid Build Coastguard Worker
3255*da0073e9SAndroid Build Coastguard Worker    def test_delete_unused_submodules_leaf(self):
3256*da0073e9SAndroid Build Coastguard Worker        class SubModule(torch.nn.Module):
3257*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
3258*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3259*da0073e9SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(10, 10)
3260*da0073e9SAndroid Build Coastguard Worker                self.relu = torch.nn.ReLU()
3261*da0073e9SAndroid Build Coastguard Worker
3262*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3263*da0073e9SAndroid Build Coastguard Worker                x = self.linear(x)
3264*da0073e9SAndroid Build Coastguard Worker                x = self.relu(x)
3265*da0073e9SAndroid Build Coastguard Worker                return x
3266*da0073e9SAndroid Build Coastguard Worker
3267*da0073e9SAndroid Build Coastguard Worker        class Model(torch.nn.Module):
3268*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
3269*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3270*da0073e9SAndroid Build Coastguard Worker                self.submod = SubModule()
3271*da0073e9SAndroid Build Coastguard Worker
3272*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3273*da0073e9SAndroid Build Coastguard Worker                x = self.submod(x)
3274*da0073e9SAndroid Build Coastguard Worker                return x
3275*da0073e9SAndroid Build Coastguard Worker
3276*da0073e9SAndroid Build Coastguard Worker        model = Model()
3277*da0073e9SAndroid Build Coastguard Worker
3278*da0073e9SAndroid Build Coastguard Worker        class MyCustomTracer(torch.fx.Tracer):
3279*da0073e9SAndroid Build Coastguard Worker            def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
3280*da0073e9SAndroid Build Coastguard Worker                return module_qualified_name == "submod"
3281*da0073e9SAndroid Build Coastguard Worker
3282*da0073e9SAndroid Build Coastguard Worker        inputs = torch.randn(1, 10)
3283*da0073e9SAndroid Build Coastguard Worker        traced_graph = MyCustomTracer().trace(model)
3284*da0073e9SAndroid Build Coastguard Worker        gm2 = torch.fx.GraphModule(model, traced_graph)
3285*da0073e9SAndroid Build Coastguard Worker        gm2.delete_all_unused_submodules()
3286*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(gm2(inputs), model(inputs))
3287*da0073e9SAndroid Build Coastguard Worker
3288*da0073e9SAndroid Build Coastguard Worker    def test_fx_stateless(self):
3289*da0073e9SAndroid Build Coastguard Worker        class MockModule(torch.nn.Module):
3290*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
3291*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3292*da0073e9SAndroid Build Coastguard Worker                self.l1 = torch.nn.Linear(1, 1)
3293*da0073e9SAndroid Build Coastguard Worker                self.buffer = torch.nn.Buffer(torch.ones(1))
3294*da0073e9SAndroid Build Coastguard Worker
3295*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3296*da0073e9SAndroid Build Coastguard Worker                return self.l1(x) + self.buffer
3297*da0073e9SAndroid Build Coastguard Worker
3298*da0073e9SAndroid Build Coastguard Worker        module = MockModule()
3299*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((1, 1))
3300*da0073e9SAndroid Build Coastguard Worker        weight = torch.tensor([[1.0]], requires_grad=True)
3301*da0073e9SAndroid Build Coastguard Worker        bias = torch.tensor([0.0], requires_grad=True)
3302*da0073e9SAndroid Build Coastguard Worker        buffer = torch.tensor([0.0])
3303*da0073e9SAndroid Build Coastguard Worker        parameters = {'l1.weight': weight,
3304*da0073e9SAndroid Build Coastguard Worker                      'l1.bias': bias,
3305*da0073e9SAndroid Build Coastguard Worker                      'buffer': buffer}
3306*da0073e9SAndroid Build Coastguard Worker        fx_module = torch.fx.symbolic_trace(module)
3307*da0073e9SAndroid Build Coastguard Worker        res = torch.func.functional_call(fx_module, parameters, x)
3308*da0073e9SAndroid Build Coastguard Worker        res.backward()
3309*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(weight.grad)
3310*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(bias.grad)
3311*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(buffer.grad)
3312*da0073e9SAndroid Build Coastguard Worker        # Gradient was not calculated for the module stated and buffers
3313*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(module.l1.weight.grad)
3314*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(module.l1.bias.grad)
3315*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(module.buffer.grad)
3316*da0073e9SAndroid Build Coastguard Worker
3317*da0073e9SAndroid Build Coastguard Worker    def test_tracing_graphmodules_as_leaf_submodules(self):
3318*da0073e9SAndroid Build Coastguard Worker        class A(torch.nn.Module):
3319*da0073e9SAndroid Build Coastguard Worker            def forward(self, t):
3320*da0073e9SAndroid Build Coastguard Worker                return t + t
3321*da0073e9SAndroid Build Coastguard Worker
3322*da0073e9SAndroid Build Coastguard Worker        class B(torch.nn.Module):
3323*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
3324*da0073e9SAndroid Build Coastguard Worker                super(type(self), self).__init__()
3325*da0073e9SAndroid Build Coastguard Worker                self.calling = False
3326*da0073e9SAndroid Build Coastguard Worker                self.called = False
3327*da0073e9SAndroid Build Coastguard Worker
3328*da0073e9SAndroid Build Coastguard Worker            def forward(self, t):
3329*da0073e9SAndroid Build Coastguard Worker                if self.calling:
3330*da0073e9SAndroid Build Coastguard Worker                    return t - t
3331*da0073e9SAndroid Build Coastguard Worker                else:
3332*da0073e9SAndroid Build Coastguard Worker                    return t + t
3333*da0073e9SAndroid Build Coastguard Worker
3334*da0073e9SAndroid Build Coastguard Worker            def __call__(self, *args):
3335*da0073e9SAndroid Build Coastguard Worker                self.called = True
3336*da0073e9SAndroid Build Coastguard Worker                self.calling = True
3337*da0073e9SAndroid Build Coastguard Worker                return super(type(self), self).__call__(*args)
3338*da0073e9SAndroid Build Coastguard Worker                self.calling = False
3339*da0073e9SAndroid Build Coastguard Worker
3340*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
3341*da0073e9SAndroid Build Coastguard Worker            def __init__(self, a, b):
3342*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3343*da0073e9SAndroid Build Coastguard Worker                self.a = a
3344*da0073e9SAndroid Build Coastguard Worker                self.b = b
3345*da0073e9SAndroid Build Coastguard Worker
3346*da0073e9SAndroid Build Coastguard Worker            def forward(self, t):
3347*da0073e9SAndroid Build Coastguard Worker                x = self.a(t)
3348*da0073e9SAndroid Build Coastguard Worker                y = self.b(t)
3349*da0073e9SAndroid Build Coastguard Worker                return x + y
3350*da0073e9SAndroid Build Coastguard Worker
3351*da0073e9SAndroid Build Coastguard Worker        class LeafTracer(Tracer):
3352*da0073e9SAndroid Build Coastguard Worker            def is_leaf_module(self, module, name):
3353*da0073e9SAndroid Build Coastguard Worker                return True
3354*da0073e9SAndroid Build Coastguard Worker
3355*da0073e9SAndroid Build Coastguard Worker        class LeafTracerNotB(Tracer):
3356*da0073e9SAndroid Build Coastguard Worker            def is_leaf_module(self, module, name):
3357*da0073e9SAndroid Build Coastguard Worker                return False if "b" in name else True
3358*da0073e9SAndroid Build Coastguard Worker
3359*da0073e9SAndroid Build Coastguard Worker        # Recompile calls added "for fun", since they
3360*da0073e9SAndroid Build Coastguard Worker        # chain __call__ wrappers.
3361*da0073e9SAndroid Build Coastguard Worker
3362*da0073e9SAndroid Build Coastguard Worker        #
3363*da0073e9SAndroid Build Coastguard Worker        # Test: B as a regular, non-leaf module
3364*da0073e9SAndroid Build Coastguard Worker        #
3365*da0073e9SAndroid Build Coastguard Worker        a = symbolic_trace(A())
3366*da0073e9SAndroid Build Coastguard Worker        a.recompile()
3367*da0073e9SAndroid Build Coastguard Worker        m = M(a, B())
3368*da0073e9SAndroid Build Coastguard Worker        graph = LeafTracerNotB().trace(m)
3369*da0073e9SAndroid Build Coastguard Worker        gm = GraphModule(m, graph)
3370*da0073e9SAndroid Build Coastguard Worker        gm.recompile()
3371*da0073e9SAndroid Build Coastguard Worker
3372*da0073e9SAndroid Build Coastguard Worker        # Test graphmodule/submodule a is not inlined.
3373*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
3374*da0073e9SAndroid Build Coastguard Worker        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
3375*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(match) == 1)
3376*da0073e9SAndroid Build Coastguard Worker
3377*da0073e9SAndroid Build Coastguard Worker        # Test submodule b is not treated as leaf.
3378*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(hasattr(gm, "b"))
3379*da0073e9SAndroid Build Coastguard Worker
3380*da0073e9SAndroid Build Coastguard Worker        # Test assert custom __call__ on submodule b was honored.
3381*da0073e9SAndroid Build Coastguard Worker        match = [
3382*da0073e9SAndroid Build Coastguard Worker            n
3383*da0073e9SAndroid Build Coastguard Worker            for n in gm.graph.nodes
3384*da0073e9SAndroid Build Coastguard Worker            if n.op == "call_function" and n.target == operator.sub
3385*da0073e9SAndroid Build Coastguard Worker        ]
3386*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(match) == 1)
3387*da0073e9SAndroid Build Coastguard Worker
3388*da0073e9SAndroid Build Coastguard Worker        #
3389*da0073e9SAndroid Build Coastguard Worker        # Test: B as a regular, leaf module
3390*da0073e9SAndroid Build Coastguard Worker        # symbolic_trace should only patch torch.nn.Module.__call__,
3391*da0073e9SAndroid Build Coastguard Worker        # which means B.__call__ should still execute
3392*da0073e9SAndroid Build Coastguard Worker        #
3393*da0073e9SAndroid Build Coastguard Worker        a = symbolic_trace(A())
3394*da0073e9SAndroid Build Coastguard Worker        a.recompile()
3395*da0073e9SAndroid Build Coastguard Worker        b = B()
3396*da0073e9SAndroid Build Coastguard Worker        m = M(a, b)
3397*da0073e9SAndroid Build Coastguard Worker        graph = LeafTracer().trace(m)
3398*da0073e9SAndroid Build Coastguard Worker        gm = GraphModule(m, graph)
3399*da0073e9SAndroid Build Coastguard Worker        gm.recompile()
3400*da0073e9SAndroid Build Coastguard Worker
3401*da0073e9SAndroid Build Coastguard Worker        # Test graphmodule/submodule a is not inlined.
3402*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
3403*da0073e9SAndroid Build Coastguard Worker        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
3404*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(match) == 1)
3405*da0073e9SAndroid Build Coastguard Worker
3406*da0073e9SAndroid Build Coastguard Worker        # Test submodule b is leaf:
3407*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module))
3408*da0073e9SAndroid Build Coastguard Worker        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"]
3409*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(match) == 1)
3410*da0073e9SAndroid Build Coastguard Worker
3411*da0073e9SAndroid Build Coastguard Worker        # Test b.__call__ was run
3412*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(b.called)
3413*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(gm.get_submodule("b").called)
3414*da0073e9SAndroid Build Coastguard Worker
3415*da0073e9SAndroid Build Coastguard Worker        #
3416*da0073e9SAndroid Build Coastguard Worker        # Test: B as GraphModule leaf
3417*da0073e9SAndroid Build Coastguard Worker        # __call__ not honored since symbolic_trace directly invokes forward()
3418*da0073e9SAndroid Build Coastguard Worker        #
3419*da0073e9SAndroid Build Coastguard Worker        a = symbolic_trace(A())
3420*da0073e9SAndroid Build Coastguard Worker        a.recompile()
3421*da0073e9SAndroid Build Coastguard Worker        b = symbolic_trace(B())
3422*da0073e9SAndroid Build Coastguard Worker        b.recompile()
3423*da0073e9SAndroid Build Coastguard Worker        m = M(a, b)
3424*da0073e9SAndroid Build Coastguard Worker        graph = LeafTracer().trace(m)
3425*da0073e9SAndroid Build Coastguard Worker        gm = GraphModule(m, graph)
3426*da0073e9SAndroid Build Coastguard Worker        gm.recompile()
3427*da0073e9SAndroid Build Coastguard Worker
3428*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
3429*da0073e9SAndroid Build Coastguard Worker        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
3430*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(match) == 1)
3431*da0073e9SAndroid Build Coastguard Worker
3432*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module))
3433*da0073e9SAndroid Build Coastguard Worker        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"]
3434*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(match) == 1)
3435*da0073e9SAndroid Build Coastguard Worker
3436*da0073e9SAndroid Build Coastguard Worker    def _test_graph_module_init_buffer_param_copied(self, use_dict_init: bool):
3437*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
3438*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
3439*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3440*da0073e9SAndroid Build Coastguard Worker                self.my_buff = torch.nn.Buffer(torch.rand(3, 4))
3441*da0073e9SAndroid Build Coastguard Worker                self.register_parameter(
3442*da0073e9SAndroid Build Coastguard Worker                    "my_param", torch.nn.Parameter(torch.rand(3, 4))
3443*da0073e9SAndroid Build Coastguard Worker                )
3444*da0073e9SAndroid Build Coastguard Worker
3445*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3446*da0073e9SAndroid Build Coastguard Worker                return x + self.my_buff + self.my_param
3447*da0073e9SAndroid Build Coastguard Worker
3448*da0073e9SAndroid Build Coastguard Worker        mod = MyModule()
3449*da0073e9SAndroid Build Coastguard Worker        mod_traced = symbolic_trace(mod)
3450*da0073e9SAndroid Build Coastguard Worker
3451*da0073e9SAndroid Build Coastguard Worker        # Create new GraphModule based on original, either w/ dict or root module.
3452*da0073e9SAndroid Build Coastguard Worker        orig_buff = mod_traced.get_buffer("my_buff")
3453*da0073e9SAndroid Build Coastguard Worker        orig_param = mod_traced.get_parameter("my_param")
3454*da0073e9SAndroid Build Coastguard Worker        mod_traced_new = GraphModule(
3455*da0073e9SAndroid Build Coastguard Worker            {"my_buff": orig_buff, "my_param": orig_param} if use_dict_init else mod,
3456*da0073e9SAndroid Build Coastguard Worker            mod_traced.graph,
3457*da0073e9SAndroid Build Coastguard Worker        )
3458*da0073e9SAndroid Build Coastguard Worker
3459*da0073e9SAndroid Build Coastguard Worker        # Check that both my_buff and my_param are found and the same.
3460*da0073e9SAndroid Build Coastguard Worker        try:
3461*da0073e9SAndroid Build Coastguard Worker            new_buff = mod_traced_new.get_buffer("my_buff")
3462*da0073e9SAndroid Build Coastguard Worker        except Exception:
3463*da0073e9SAndroid Build Coastguard Worker            self.fail("Did not find my_buff")
3464*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(orig_buff, new_buff)
3465*da0073e9SAndroid Build Coastguard Worker
3466*da0073e9SAndroid Build Coastguard Worker        try:
3467*da0073e9SAndroid Build Coastguard Worker            new_param = mod_traced_new.get_parameter("my_param")
3468*da0073e9SAndroid Build Coastguard Worker        except Exception:
3469*da0073e9SAndroid Build Coastguard Worker            self.fail("Did not find my_param")
3470*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(orig_param, new_param)
3471*da0073e9SAndroid Build Coastguard Worker
3472*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
3473*da0073e9SAndroid Build Coastguard Worker        orig_out = mod_traced(x)
3474*da0073e9SAndroid Build Coastguard Worker        submodules_out = mod_traced_new(x)
3475*da0073e9SAndroid Build Coastguard Worker
3476*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(orig_out, submodules_out)
3477*da0073e9SAndroid Build Coastguard Worker
3478*da0073e9SAndroid Build Coastguard Worker    def test_graph_module_init_buffer_param_copied_dict_init(self):
3479*da0073e9SAndroid Build Coastguard Worker        self._test_graph_module_init_buffer_param_copied(use_dict_init=True)
3480*da0073e9SAndroid Build Coastguard Worker
3481*da0073e9SAndroid Build Coastguard Worker    def test_graph_module_init_buffer_param_copied_mod_init(self):
3482*da0073e9SAndroid Build Coastguard Worker        self._test_graph_module_init_buffer_param_copied(use_dict_init=False)
3483*da0073e9SAndroid Build Coastguard Worker
3484*da0073e9SAndroid Build Coastguard Worker    def test_annotations_with_no_forward_references(self):
3485*da0073e9SAndroid Build Coastguard Worker        class A:
3486*da0073e9SAndroid Build Coastguard Worker            def __call__(self, x: torch.Tensor):
3487*da0073e9SAndroid Build Coastguard Worker                return torch.add(x, x)
3488*da0073e9SAndroid Build Coastguard Worker
3489*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
3490*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor, a: A) -> torch.Tensor:
3491*da0073e9SAndroid Build Coastguard Worker                return a(x)
3492*da0073e9SAndroid Build Coastguard Worker
3493*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
3494*da0073e9SAndroid Build Coastguard Worker
3495*da0073e9SAndroid Build Coastguard Worker    def test_annotations_with_forward_references(self):
3496*da0073e9SAndroid Build Coastguard Worker        class A:
3497*da0073e9SAndroid Build Coastguard Worker            def __call__(self, x: torch.Tensor):
3498*da0073e9SAndroid Build Coastguard Worker                return torch.add(x, x)
3499*da0073e9SAndroid Build Coastguard Worker
3500*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
3501*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: 'torch.Tensor', a: 'A') -> 'torch.Tensor':
3502*da0073e9SAndroid Build Coastguard Worker                return a(x)
3503*da0073e9SAndroid Build Coastguard Worker
3504*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
3505*da0073e9SAndroid Build Coastguard Worker
3506*da0073e9SAndroid Build Coastguard Worker    def test_annotations_with_non_torch_reference_and_no_internal_forward_references(self):
3507*da0073e9SAndroid Build Coastguard Worker        class A:
3508*da0073e9SAndroid Build Coastguard Worker            def __call__(self, x: torch.Tensor):
3509*da0073e9SAndroid Build Coastguard Worker                return torch.add(x, x)
3510*da0073e9SAndroid Build Coastguard Worker
3511*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
3512*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: List[torch.Tensor], a: A) -> torch.Tensor:
3513*da0073e9SAndroid Build Coastguard Worker                return a(x[0])
3514*da0073e9SAndroid Build Coastguard Worker
3515*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
3516*da0073e9SAndroid Build Coastguard Worker
3517*da0073e9SAndroid Build Coastguard Worker    def test_annotations_with_non_torch_reference_and_internal_forward_references(self):
3518*da0073e9SAndroid Build Coastguard Worker        class A:
3519*da0073e9SAndroid Build Coastguard Worker            def __call__(self, x: torch.Tensor):
3520*da0073e9SAndroid Build Coastguard Worker                return torch.add(x, x)
3521*da0073e9SAndroid Build Coastguard Worker
3522*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
3523*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: List['torch.Tensor'], a: A) -> 'torch.Tensor':
3524*da0073e9SAndroid Build Coastguard Worker                return a(x)[0]
3525*da0073e9SAndroid Build Coastguard Worker
3526*da0073e9SAndroid Build Coastguard Worker        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
3527*da0073e9SAndroid Build Coastguard Worker
    @unittest.skipIf(sys.version_info < (3, 7), "`__future__` feature "
                     "`annotations` is not defined in Python <3.7")
    def test_annotation_with_future(self):
        """Pass iff importing the companion module ``fx.test_future`` succeeds.

        The body does nothing but import that module, so any checks this test
        performs run at import time of ``fx.test_future`` (not visible here).
        NOTE(review): per the skip message, that module presumably relies on
        ``from __future__ import annotations`` — confirm in fx/test_future.py.
        """
        try:
            import fx.test_future    # noqa: F401
        finally:
            # Remove ``__future__`` from the module cache whether or not the
            # import succeeded. NOTE(review): the reason is not visible here;
            # presumably it undoes a side effect of the import — confirm.
            del sys.modules["__future__"]
3535*da0073e9SAndroid Build Coastguard Worker
3536*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(sys.version_info > (3, 11), "Does not work in 3.11")
3537*da0073e9SAndroid Build Coastguard Worker    def test_annotations_empty_tuple(self):
3538*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
3539*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: Tuple[()], y: Tuple[str, Tuple[()]]):
3540*da0073e9SAndroid Build Coastguard Worker                return "foo"
3541*da0073e9SAndroid Build Coastguard Worker
3542*da0073e9SAndroid Build Coastguard Worker        traced = torch.fx.symbolic_trace(Foo())
3543*da0073e9SAndroid Build Coastguard Worker
3544*da0073e9SAndroid Build Coastguard Worker        x = ()
3545*da0073e9SAndroid Build Coastguard Worker        y = ("bar", ())
3546*da0073e9SAndroid Build Coastguard Worker
3547*da0073e9SAndroid Build Coastguard Worker        traced(x, y)
3548*da0073e9SAndroid Build Coastguard Worker
3549*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("_Tuple[()]")   \
3550*da0073e9SAndroid Build Coastguard Worker                   .check("typing_Tuple[str,typing_Tuple[()]]") \
3551*da0073e9SAndroid Build Coastguard Worker                   .run(traced.code)
3552*da0073e9SAndroid Build Coastguard Worker
3553*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.script(traced)
3554*da0073e9SAndroid Build Coastguard Worker
3555*da0073e9SAndroid Build Coastguard Worker        scripted(x, y)
3556*da0073e9SAndroid Build Coastguard Worker
3557*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("Tuple[()]")   \
3558*da0073e9SAndroid Build Coastguard Worker            .check("Tuple[str, Tuple[()]]")    \
3559*da0073e9SAndroid Build Coastguard Worker            .run(scripted.code)
3560*da0073e9SAndroid Build Coastguard Worker
3561*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "Python Windows bug? https://bugs.python.org/issue45108")
3562*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(sys.version_info >= (3, 10), "Does not work on Python-3.10")
3563*da0073e9SAndroid Build Coastguard Worker    def test_assert(self):
3564*da0073e9SAndroid Build Coastguard Worker        def f(x):
3565*da0073e9SAndroid Build Coastguard Worker            assert x > 1
3566*da0073e9SAndroid Build Coastguard Worker            return x + 1
3567*da0073e9SAndroid Build Coastguard Worker        try:
3568*da0073e9SAndroid Build Coastguard Worker            torch.fx.proxy.TracerBase.trace_asserts = True
3569*da0073e9SAndroid Build Coastguard Worker            traced = symbolic_trace(f)
3570*da0073e9SAndroid Build Coastguard Worker        finally:
3571*da0073e9SAndroid Build Coastguard Worker            torch.fx.proxy.TracerBase.trace_asserts = False
3572*da0073e9SAndroid Build Coastguard Worker
3573*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(f(2), traced(2))
3574*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
3575*da0073e9SAndroid Build Coastguard Worker            traced(0)
3576*da0073e9SAndroid Build Coastguard Worker
    def test_pytree(self):
        """Trace functions whose single argument ``x`` is a pytree containing placeholders.

        Each ``(function, input_spec)`` pair in ``tests`` is handed to
        ``verify_pytree``, which traces with ``concrete_args={'x': input_spec}``
        and compares the result against eager execution, a rebuild using the
        default ``CodeGen``, re-tracing with and without ``concrete_args``, and
        a pickle round-trip.
        """
        # Used to test that you can use your own placeholder class
        class PHTest(PHBase):
            pass

        def f_sum(x):
            return sum(x)

        def f_sum_dict(x):
            out = 0
            for v in x.values():
                out += v
            return out

        def f_dict_list_map(x):
            new_dict = {}
            for k, v in x.items():
                new_dict[k] = [i + 1 for i in v]
            return new_dict

        def f_dict_add(x):
            return x['a'] + sum(x['z'])

        def f_namedtuple_add(x):
            return x.x + x.y

        # Register Foo with pytree (flatten / unflatten) and with FX's
        # flatten-spec registry so Foo instances can appear inside concrete_args.
        # NOTE(review): these registrations are global and never undone;
        # registering Foo a second time (e.g. re-running this test in the same
        # process) may raise — confirm against pytree.register_pytree_node.
        pytree.register_pytree_node(
            Foo,
            lambda x: ([x.a, x.b], None),
            lambda x, _: Foo(x[0], x[1]),
        )
        fx_pytree.register_pytree_flatten_spec(Foo, lambda x, _: [x.a, x.b])

        def f_custom(x):
            return x.a + x.b

        def f_custom_dict(x):
            return f_sum_dict(x.a) + x.b

        # Only referenced by the disabled output-pytree case in ``tests``.
        def f_return_custom(x):
            return Foo(x.b, x.a)

        # (function, input spec) pairs. PHBase instances in a spec mark leaves
        # that are replaced by tensors at runtime; any other leaf (e.g. the 3
        # in Foo(PH, 3)) is kept as-is in the runtime inputs.
        tests = [
            (f_sum, [PH, PH, PH]),
            (f_sum, []),
            (f_sum, [PHTest(), PHTest(), PHTest()]),
            (f_sum_dict, {'a': PH, 'b': PH, 'c': PH}),
            (f_dict_list_map, {'a': (PH, PH), 'b': [PH], 'c': []}),
            (f_dict_list_map, {5: (PH, PH, PH)}),
            (f_dict_add, {'a': PH, 'z': (PH, PH, PH)}),
            (f_dict_add, {'a': PH, 'z': []}),
            (f_custom, Foo(PH, PH)),
            (f_custom, Foo(PH, 3)),
            (f_custom_dict, Foo({'a': PH, 'b': PH}, PH)),
            # (f_return_custom, Foo(PH, PH)), # Don't currently support output pytrees
            (f_namedtuple_add, Point(PH, PH)),
        ]

        def verify_pytree(f, inp):
            # Runtime inputs: each placeholder leaf becomes a fresh tensor of
            # shape (3,); every other leaf is kept unchanged.
            val = pytree.tree_map(lambda x: torch.randn(3) if isinstance(x, PHBase) else x, inp)
            # One placeholder node is expected per leaf of the spec (asserted below).
            num_flat_args = len(pytree.tree_leaves(inp))
            orig_out = f(val)
            nf = symbolic_trace(f, concrete_args={'x': inp})
            self.assertEqual(nf(val), orig_out)

            # Rebuild the same graph with the default CodeGen, which takes flat
            # positional inputs. nf's graph-level process_inputs /
            # process_outputs are then applied by hand around the call.
            bare_fx = GraphModule({}, copy.deepcopy(nf.graph))
            bare_fx.graph.set_codegen(CodeGen())
            bare_fx.recompile()
            self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(val))), orig_out)

            # When the spec has leaves, the generated code flattens x via tree_flatten_spec.
            assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
            assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == num_flat_args

            # Re-tracing without concrete_args traces through the flattening,
            # leaving a single placeholder for the whole of x.
            nf = symbolic_trace(nf)
            self.assertEqual(nf(val), orig_out)
            assert "tree_flatten_spec" not in nf.code
            assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == 1

            # Re-tracing with concrete_args brings back one placeholder per leaf.
            nf = symbolic_trace(nf, concrete_args={'x': inp})
            self.assertEqual(nf(val), orig_out)
            assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
            assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == num_flat_args

            # The flattening GraphModule must survive a pickle round-trip.
            pickled = pickle.dumps(nf)
            nf = pickle.loads(pickled)
            self.assertEqual(nf(val), orig_out)

        for f, inp in tests:
            verify_pytree(f, inp)
3666*da0073e9SAndroid Build Coastguard Worker
3667*da0073e9SAndroid Build Coastguard Worker    def test_pytree_concrete(self):
3668*da0073e9SAndroid Build Coastguard Worker        def f(b, a):
3669*da0073e9SAndroid Build Coastguard Worker            if b:
3670*da0073e9SAndroid Build Coastguard Worker                return a['a']
3671*da0073e9SAndroid Build Coastguard Worker            else:
3672*da0073e9SAndroid Build Coastguard Worker                return a['z']
3673*da0073e9SAndroid Build Coastguard Worker
3674*da0073e9SAndroid Build Coastguard Worker        inp = {'a': {'a': PH, 'z': PH}, 'b': True}
3675*da0073e9SAndroid Build Coastguard Worker        nf = symbolic_trace(f, concrete_args=inp)
3676*da0073e9SAndroid Build Coastguard Worker        val = pytree.tree_map(lambda x: torch.randn(3) if x == PH else x, inp)
3677*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nf(**val), f(**val))
3678*da0073e9SAndroid Build Coastguard Worker
3679*da0073e9SAndroid Build Coastguard Worker        nf = symbolic_trace(nf)
3680*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nf(**val), f(**val))
3681*da0073e9SAndroid Build Coastguard Worker
3682*da0073e9SAndroid Build Coastguard Worker    def test_metadata_on_ph(self):
3683*da0073e9SAndroid Build Coastguard Worker        def f_sum(a: int, b: int) -> int:
3684*da0073e9SAndroid Build Coastguard Worker            return a + b
3685*da0073e9SAndroid Build Coastguard Worker
3686*da0073e9SAndroid Build Coastguard Worker        # Due to unflattening of dict, the batch argument
3687*da0073e9SAndroid Build Coastguard Worker        # will be split into two separate nodes with the names
3688*da0073e9SAndroid Build Coastguard Worker        # "batch_1" and "batch_2", referring to the keys
3689*da0073e9SAndroid Build Coastguard Worker        # "f1" and "f2" respectively in the dict.
3690*da0073e9SAndroid Build Coastguard Worker        def f_dict(a: Dict[str, str]) -> bool:
3691*da0073e9SAndroid Build Coastguard Worker            return a["f1"] == a["f2"]
3692*da0073e9SAndroid Build Coastguard Worker
3693*da0073e9SAndroid Build Coastguard Worker        def verify_metadata(gm: GraphModule, arg_names: List[str], metadata: List[str]):
3694*da0073e9SAndroid Build Coastguard Worker            for node in gm.graph.nodes:
3695*da0073e9SAndroid Build Coastguard Worker                if node.op == "placeholder":
3696*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(node.name in arg_names)
3697*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(node.ph_key in metadata)
3698*da0073e9SAndroid Build Coastguard Worker
3699*da0073e9SAndroid Build Coastguard Worker        verify_metadata(
3700*da0073e9SAndroid Build Coastguard Worker            gm=symbolic_trace(
3701*da0073e9SAndroid Build Coastguard Worker                f_sum,
3702*da0073e9SAndroid Build Coastguard Worker                concrete_args={"a": PHWithMeta(ph_key="a"), "b": PHWithMeta(ph_key="b")}
3703*da0073e9SAndroid Build Coastguard Worker            ),
3704*da0073e9SAndroid Build Coastguard Worker            arg_names=["a_1", "b_1"],
3705*da0073e9SAndroid Build Coastguard Worker            metadata=["a", "b"]
3706*da0073e9SAndroid Build Coastguard Worker        )
3707*da0073e9SAndroid Build Coastguard Worker        verify_metadata(
3708*da0073e9SAndroid Build Coastguard Worker            gm=symbolic_trace(
3709*da0073e9SAndroid Build Coastguard Worker                f_dict,
3710*da0073e9SAndroid Build Coastguard Worker                concrete_args={"a": {"f1": PHWithMeta(ph_key="f1"), "f2": PHWithMeta(ph_key="f2")}}
3711*da0073e9SAndroid Build Coastguard Worker            ),
3712*da0073e9SAndroid Build Coastguard Worker            arg_names=["a_1", "a_2"],
3713*da0073e9SAndroid Build Coastguard Worker            metadata=["f1", "f2"]
3714*da0073e9SAndroid Build Coastguard Worker        )
3715*da0073e9SAndroid Build Coastguard Worker
3716*da0073e9SAndroid Build Coastguard Worker        # Ensures that tags on nodes are NOT overwritten by PH attributes with same attr name (tag)
3717*da0073e9SAndroid Build Coastguard Worker        class TaggingTracer(Tracer):
3718*da0073e9SAndroid Build Coastguard Worker            def create_node(self, kind : str, target : Union[str, Callable],
3719*da0073e9SAndroid Build Coastguard Worker                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
3720*da0073e9SAndroid Build Coastguard Worker                            type_expr : Optional[Any] = None) -> Node:
3721*da0073e9SAndroid Build Coastguard Worker                n = super().create_node(kind, target, args, kwargs, name)
3722*da0073e9SAndroid Build Coastguard Worker                n.tag = "foo"
3723*da0073e9SAndroid Build Coastguard Worker                return n
3724*da0073e9SAndroid Build Coastguard Worker
3725*da0073e9SAndroid Build Coastguard Worker        class PHWithTag(PHBase):
3726*da0073e9SAndroid Build Coastguard Worker            def __init__(self, tag: str):
3727*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3728*da0073e9SAndroid Build Coastguard Worker
3729*da0073e9SAndroid Build Coastguard Worker                self.tag = tag
3730*da0073e9SAndroid Build Coastguard Worker
3731*da0073e9SAndroid Build Coastguard Worker        g = TaggingTracer().trace(f_sum, concrete_args={"a": PHWithTag(tag="bar"), "b": PHWithTag(tag="bar")})
3732*da0073e9SAndroid Build Coastguard Worker        for n in g.nodes:
3733*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(hasattr(n, "tag"))
3734*da0073e9SAndroid Build Coastguard Worker            # Ensure that tag is still "foo" and not "bar" (from PHWithTag)
3735*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(n.tag, "foo")
3736*da0073e9SAndroid Build Coastguard Worker
    def test_custom_codegen(self):
        """Swap in a CodeGen whose generated ``forward`` takes one list of tensors.

        After recompiling with ``ListCodeGen`` the module is called as
        ``nf(vals)`` instead of ``nf(*vals)``. The test also rebuilds the graph
        with the default ``CodeGen`` and drives it through the custom
        ``process_inputs`` / ``process_outputs``. Finally it checks that
        ``torch.jit.script`` accepts the custom-generated module.
        """
        class ListCodeGen(CodeGen):
            # Emit ``def forward(self, args_list: List[torch.Tensor])`` and unpack
            # the list into the graph's free variables on the first body line.
            def gen_fn_def(self, free_vars, maybe_return_annotation):
                lst_unpack = f"""
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
    {', '.join(free_vars)} = args_list"""
                return lst_unpack

            # Make ``List`` (used in the generated signature) resolvable in the
            # generated code's globals.
            def additional_globals(self):
                return [('List', typing.List)]

            # Callers pass a single list; return it so it can be splatted into
            # a flat-argument call.
            def process_inputs(self, *inputs):
                assert len(inputs) == 1
                return inputs[0]

        def f(a, b):
            return a + b

        nf = symbolic_trace(f)
        vals = [torch.randn(3), torch.randn(3)]
        self.assertEqual(nf(*vals), f(*vals))

        nf.graph.set_codegen(ListCodeGen())
        nf.recompile()

        # Copy of the graph that uses the default (flat positional argument) CodeGen.
        bare_fx = GraphModule({}, copy.deepcopy(nf.graph))
        bare_fx.graph.set_codegen(CodeGen())
        bare_fx.recompile()

        self.assertEqual(nf(vals), f(*vals))
        self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(vals))), f(*vals))

        # The custom-generated module must also be scriptable.
        ts_f = torch.jit.script(nf)
        self.assertEqual(nf(vals), ts_f(vals))
3771*da0073e9SAndroid Build Coastguard Worker
    def test_custom_codegen_with_transformer(self):
        """A module using a custom list-input CodeGen still works after ``Transformer``.

        The module produced by ``Transformer(nf).transform()`` is expected to
        keep the single-list calling convention of ``ListCodeGen``.
        """
        class ListCodeGen(CodeGen):
            # Emit ``def forward(self, args_list: List[torch.Tensor])`` and unpack
            # the list into the graph's free variables on the first body line.
            def gen_fn_def(self, free_vars, maybe_return_annotation):
                lst_unpack = f"""
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
    {', '.join(free_vars)} = args_list"""
                return lst_unpack

            # Make ``List`` (used in the generated signature) resolvable in the
            # generated code's globals.
            def additional_globals(self):
                return [('List', typing.List)]

            # Callers pass a single list; return it as the flat argument sequence.
            def process_inputs(self, *inputs):
                assert len(inputs) == 1
                return inputs[0]

        def f(a, b):
            return a + b

        nf = symbolic_trace(f)
        vals = [torch.randn(3), torch.randn(3)]
        self.assertEqual(nf(*vals), f(*vals))

        nf.graph.set_codegen(ListCodeGen())
        nf.recompile()
        self.assertEqual(nf(vals), f(*vals))

        # The transformed module is called with the same single list argument.
        transformed_gm = Transformer(nf).transform()
        self.assertEqual(nf(vals), transformed_gm(vals))
3800*da0073e9SAndroid Build Coastguard Worker
    def test_interpreter_with_codegen(self):
        """``Interpreter.run`` honours a custom CodeGen's input/output processing.

        ``ListCodeGen`` takes all inputs as one list and returns outputs as a
        list. The test expects ``Interpreter(nf).run(vals)`` to accept the same
        single-list input as ``nf`` and to produce the same result.
        """
        class ListCodeGen(CodeGen):
            # Emit ``def forward(self, args_list: List[torch.Tensor])`` and unpack
            # the list into the graph's free variables on the first body line.
            def gen_fn_def(self, free_vars, maybe_return_annotation):
                lst_unpack = f"""
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
    {', '.join(free_vars)} = args_list"""
                return lst_unpack

            # Make ``List`` (used in the generated signature) resolvable in the
            # generated code's globals.
            def additional_globals(self):
                return [('List', typing.List)]

            # Callers pass a single list; return it as the flat argument sequence.
            def process_inputs(self, *inputs):
                assert len(inputs) == 1
                return inputs[0]

            # Generated forward returns its outputs wrapped in ``list(...)``.
            def generate_output(self, output_args):
                return f'return list({repr(output_args)})'

            # Non-generated paths convert outputs to a list the same way.
            def process_outputs(self, outputs):
                return list(outputs)

        def f(a, b):
            a = a + b
            b = a + b
            return a, b

        nf = symbolic_trace(f)
        vals = [torch.randn(3), torch.randn(3)]
        nf.graph.set_codegen(ListCodeGen())
        nf.recompile()
        self.assertEqual(Interpreter(nf).run(vals), nf(vals))
3832*da0073e9SAndroid Build Coastguard Worker
3833*da0073e9SAndroid Build Coastguard Worker    def test_imul_code_print(self):
3834*da0073e9SAndroid Build Coastguard Worker        graph = torch.fx.Graph()
3835*da0073e9SAndroid Build Coastguard Worker        a = graph.placeholder("a")
3836*da0073e9SAndroid Build Coastguard Worker        b = graph.placeholder("b")
3837*da0073e9SAndroid Build Coastguard Worker        graph.call_function(operator.imul, (a, b), {})
3838*da0073e9SAndroid Build Coastguard Worker        graph.output(a)
3839*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.GraphModule({}, graph)
3840*da0073e9SAndroid Build Coastguard Worker        gm.recompile()
3841*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gm(2, 3), 6)
3842*da0073e9SAndroid Build Coastguard Worker        self.assertIn("a *= b", gm.code)
3843*da0073e9SAndroid Build Coastguard Worker
3844*da0073e9SAndroid Build Coastguard Worker    def test_deepcopy_tracer(self):
3845*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
3846*da0073e9SAndroid Build Coastguard Worker            return (x + y).relu().sin()
3847*da0073e9SAndroid Build Coastguard Worker
3848*da0073e9SAndroid Build Coastguard Worker        tracer = Tracer()
3849*da0073e9SAndroid Build Coastguard Worker        tracer_before = copy.deepcopy(tracer)
3850*da0073e9SAndroid Build Coastguard Worker        tracer.trace(fn)
3851*da0073e9SAndroid Build Coastguard Worker        tracer_after = copy.deepcopy(tracer)
3852*da0073e9SAndroid Build Coastguard Worker
3853*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(str(tracer.graph), str(tracer_after.graph))
3854*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(not hasattr(tracer_before, 'graph') or str(tracer.graph) != str(tracer_before.graph))
3855*da0073e9SAndroid Build Coastguard Worker
3856*da0073e9SAndroid Build Coastguard Worker    def test_deepcopy_graphmodule(self):
3857*da0073e9SAndroid Build Coastguard Worker        m = symbolic_trace(SimpleTest())
3858*da0073e9SAndroid Build Coastguard Worker        m.meta['hello'] = 'world'
3859*da0073e9SAndroid Build Coastguard Worker        copy_m = copy.deepcopy(m)
3860*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(copy_m.meta['hello'], 'world')
3861*da0073e9SAndroid Build Coastguard Worker
3862*da0073e9SAndroid Build Coastguard Worker    def test_deepcopy_no_recursion(self):
3863*da0073e9SAndroid Build Coastguard Worker        m = symbolic_trace(SimpleTest())
3864*da0073e9SAndroid Build Coastguard Worker        m.meta['hello'] = m  # circular reference
3865*da0073e9SAndroid Build Coastguard Worker        copy_m = copy.deepcopy(m)  # finishes
3866*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(id(copy_m), id(copy_m.meta['hello']))
3867*da0073e9SAndroid Build Coastguard Worker
3868*da0073e9SAndroid Build Coastguard Worker    def test_enum(self):
3869*da0073e9SAndroid Build Coastguard Worker        from enum import Enum
3870*da0073e9SAndroid Build Coastguard Worker
3871*da0073e9SAndroid Build Coastguard Worker        class Foo(Enum):
3872*da0073e9SAndroid Build Coastguard Worker            A = 1
3873*da0073e9SAndroid Build Coastguard Worker            B = 2
3874*da0073e9SAndroid Build Coastguard Worker
3875*da0073e9SAndroid Build Coastguard Worker        def leaf_fn(arr, enum_val):
3876*da0073e9SAndroid Build Coastguard Worker            # Use the raw enum.
3877*da0073e9SAndroid Build Coastguard Worker            arr.append(enum_val)
3878*da0073e9SAndroid Build Coastguard Worker            return arr[-1].value
3879*da0073e9SAndroid Build Coastguard Worker
3880*da0073e9SAndroid Build Coastguard Worker        def foo(x):
3881*da0073e9SAndroid Build Coastguard Worker            # Pass the enum as argument.
3882*da0073e9SAndroid Build Coastguard Worker            return leaf_fn(x, Foo.A)
3883*da0073e9SAndroid Build Coastguard Worker
3884*da0073e9SAndroid Build Coastguard Worker        traced = torch.fx.symbolic_trace(foo)
3885*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo([]), traced([]))
3886*da0073e9SAndroid Build Coastguard Worker
3887*da0073e9SAndroid Build Coastguard Worker    def test_insert_arg(self):
3888*da0073e9SAndroid Build Coastguard Worker        m = symbolic_trace(SimpleTest())
3889*da0073e9SAndroid Build Coastguard Worker        m.buf = torch.nn.Buffer(torch.tensor(0))
3890*da0073e9SAndroid Build Coastguard Worker        output_node = next(iter(reversed(m.graph.nodes)))
3891*da0073e9SAndroid Build Coastguard Worker        with m.graph.inserting_before(output_node):
3892*da0073e9SAndroid Build Coastguard Worker            a = m.graph.get_attr("buf")
3893*da0073e9SAndroid Build Coastguard Worker        r = len(output_node.args)
3894*da0073e9SAndroid Build Coastguard Worker        output_node.insert_arg(0, a)
3895*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(output_node.args), r + 1)
3896*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(a.users), 1)
3897*da0073e9SAndroid Build Coastguard Worker        self.assertIs(output_node.args[0], a)
3898*da0073e9SAndroid Build Coastguard Worker        self.assertIs(next(iter(a.users.keys())), output_node)
3899*da0073e9SAndroid Build Coastguard Worker        output_node.insert_arg(2, a)
3900*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(output_node.args), r + 2)
3901*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(a.users), 1)
3902*da0073e9SAndroid Build Coastguard Worker        self.assertIs(output_node.args[2], a)
3903*da0073e9SAndroid Build Coastguard Worker        self.assertIs(next(iter(a.users.keys())), output_node)
3904*da0073e9SAndroid Build Coastguard Worker        m.graph.lint()
3905*da0073e9SAndroid Build Coastguard Worker
3906*da0073e9SAndroid Build Coastguard Worker    def test_delete_unused_values(self):
3907*da0073e9SAndroid Build Coastguard Worker        from torch.fx.experimental.proxy_tensor import make_fx
3908*da0073e9SAndroid Build Coastguard Worker
3909*da0073e9SAndroid Build Coastguard Worker        # disable mutable checking temporarily
3910*da0073e9SAndroid Build Coastguard Worker        orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
3911*da0073e9SAndroid Build Coastguard Worker        torch.fx.proxy.TracerBase.check_mutable_operations = False
3912*da0073e9SAndroid Build Coastguard Worker
3913*da0073e9SAndroid Build Coastguard Worker        def fn(a, b, c, d):
3914*da0073e9SAndroid Build Coastguard Worker            x = a + b
3915*da0073e9SAndroid Build Coastguard Worker            y = c + d
3916*da0073e9SAndroid Build Coastguard Worker            y.copy_(x)
3917*da0073e9SAndroid Build Coastguard Worker            x = torch.relu(x)
3918*da0073e9SAndroid Build Coastguard Worker            return x
3919*da0073e9SAndroid Build Coastguard Worker
3920*da0073e9SAndroid Build Coastguard Worker        a, b, c, d = (torch.randn(2, 4, requires_grad=False) for _ in range(4))
3921*da0073e9SAndroid Build Coastguard Worker        fx_fn = make_fx(fn)(a, b, c, d)
3922*da0073e9SAndroid Build Coastguard Worker        print(fx_fn)
3923*da0073e9SAndroid Build Coastguard Worker
3924*da0073e9SAndroid Build Coastguard Worker        fx_fn.graph.eliminate_dead_code()
3925*da0073e9SAndroid Build Coastguard Worker        py_code = fx_fn.recompile()
3926*da0073e9SAndroid Build Coastguard Worker        self.assertTrue("copy_ = torch.ops.aten.copy_.default" in py_code.src)
3927*da0073e9SAndroid Build Coastguard Worker        self.assertTrue("copy_ = None" in py_code.src)
3928*da0073e9SAndroid Build Coastguard Worker
3929*da0073e9SAndroid Build Coastguard Worker        # recorver mutable checking flag
3930*da0073e9SAndroid Build Coastguard Worker        torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag
3931*da0073e9SAndroid Build Coastguard Worker
def run_getitem_target():
    """Run ``TestFX.getitem_inner`` with ``(torch.Tensor, "__getitem__")`` registered
    in ``_wrapped_methods_to_patch``.

    The extra entry exists only for the duration of the call and is always
    popped afterwards, even if the inner test raises.
    """
    from torch.fx._symbolic_trace import _wrapped_methods_to_patch

    patched_entry = (torch.Tensor, "__getitem__")
    _wrapped_methods_to_patch.append(patched_entry)
    try:
        TestFX().getitem_inner()
    finally:
        _wrapped_methods_to_patch.pop()
3939*da0073e9SAndroid Build Coastguard Worker
3940*da0073e9SAndroid Build Coastguard Worker
class TestOperatorSignatures(JitTestCase):
    """Check that ``get_signature_for_torch_op`` yields a schema binding every OpInfo sample.

    NOTE(review): ``setUp``/``tearDown`` deliberately do not call ``super()``.
    This is presumably because the class is expanded via
    ``instantiate_device_type_tests``, where zero-argument ``super()`` inside
    ported methods may not resolve. Confirm before adding super() calls.
    """
    def setUp(self):
        # Checking for mutable operations while tracing is feature flagged.
        # Enable it in testing but not by default; tearDown restores it.
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

    def tearDown(self):
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag

    @onlyCPU
    @ops(op_db, allowed_dtypes=(torch.float,))
    def test_get_torch_func_signature_exhaustive(self, device, dtype, op):
        """Every sample input of ``op`` must bind to, and run under, some schema overload.

        Raises:
            unittest.SkipTest: if ``op.op`` is not a builtin function.
            RuntimeError: if no schemas are returned, or a sample input matches none.
        """
        if not isinstance(op.op, types.BuiltinFunctionType):
            raise unittest.SkipTest("This path doesn't work on Python functions")
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
        schemas = get_signature_for_torch_op(op.op)
        if not schemas:
            raise RuntimeError('No Schemas Returned')
        for sample_input in sample_inputs_itr:
            # Iterate through overloads until we hit a match. If we exit this
            # loop via `else`, we haven't found a match
            for schema in schemas:
                try:
                    bound_args = schema.bind(sample_input.input, *sample_input.args, **sample_input.kwargs)
                    bound_args.apply_defaults()
                    op(*bound_args.args, **bound_args.kwargs)
                    break
                except TypeError:
                    # A TypeError from binding or from the call means this
                    # overload does not fit; try the next one.
                    pass
            else:
                raise RuntimeError(f'Did not match any schemas for op {op.name}!')
3973*da0073e9SAndroid Build Coastguard Worker
3974*da0073e9SAndroid Build Coastguard Worker
3975*da0073e9SAndroid Build Coastguard Workerclass TestFXAPIBackwardCompatibility(JitTestCase):
3976*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
3977*da0073e9SAndroid Build Coastguard Worker        super().setUp()
3978*da0073e9SAndroid Build Coastguard Worker        self.maxDiff = None
3979*da0073e9SAndroid Build Coastguard Worker
3980*da0073e9SAndroid Build Coastguard Worker        # Checking for mutable operations whil tracing is feature flagged
3981*da0073e9SAndroid Build Coastguard Worker        # Enable it in testing but not by default
3982*da0073e9SAndroid Build Coastguard Worker        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
3983*da0073e9SAndroid Build Coastguard Worker        torch.fx.proxy.TracerBase.check_mutable_operations = True
3984*da0073e9SAndroid Build Coastguard Worker
3985*da0073e9SAndroid Build Coastguard Worker    def tearDown(self):
3986*da0073e9SAndroid Build Coastguard Worker        super().tearDown()
3987*da0073e9SAndroid Build Coastguard Worker        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag
3988*da0073e9SAndroid Build Coastguard Worker
3989*da0073e9SAndroid Build Coastguard Worker
3990*da0073e9SAndroid Build Coastguard Worker    def _fn_to_stable_annotation_str(self, obj):
3991*da0073e9SAndroid Build Coastguard Worker        """
3992*da0073e9SAndroid Build Coastguard Worker        Unfortunately we have to serialize function signatures manually since
3993*da0073e9SAndroid Build Coastguard Worker        serialization for `inspect.Signature` objects is not stable across
3994*da0073e9SAndroid Build Coastguard Worker        python versions
3995*da0073e9SAndroid Build Coastguard Worker        """
3996*da0073e9SAndroid Build Coastguard Worker        fn_name = torch.typename(obj)
3997*da0073e9SAndroid Build Coastguard Worker
3998*da0073e9SAndroid Build Coastguard Worker        signature = inspect.signature(obj)
3999*da0073e9SAndroid Build Coastguard Worker
4000*da0073e9SAndroid Build Coastguard Worker        sig_str = f'{fn_name}{signature}'
4001*da0073e9SAndroid Build Coastguard Worker
4002*da0073e9SAndroid Build Coastguard Worker        arg_strs = []
4003*da0073e9SAndroid Build Coastguard Worker        for k, v in signature.parameters.items():
4004*da0073e9SAndroid Build Coastguard Worker            maybe_type_annotation = f': {self._annotation_type_to_stable_str(v.annotation, sig_str)}'\
4005*da0073e9SAndroid Build Coastguard Worker                if v.annotation is not inspect.Signature.empty else ''
4006*da0073e9SAndroid Build Coastguard Worker
4007*da0073e9SAndroid Build Coastguard Worker            def default_val_str(val):
4008*da0073e9SAndroid Build Coastguard Worker                if isinstance(val, (tuple, list)):
4009*da0073e9SAndroid Build Coastguard Worker                    str_pieces = ['(' if isinstance(val, tuple) else '[']
4010*da0073e9SAndroid Build Coastguard Worker                    str_pieces.append(', '.join(default_val_str(v) for v in val))
4011*da0073e9SAndroid Build Coastguard Worker                    if isinstance(val, tuple) and len(str_pieces) == 2:
4012*da0073e9SAndroid Build Coastguard Worker                        str_pieces.append(',')
4013*da0073e9SAndroid Build Coastguard Worker                    str_pieces.append(')' if isinstance(val, tuple) else ']')
4014*da0073e9SAndroid Build Coastguard Worker                    return ''.join(str_pieces)
4015*da0073e9SAndroid Build Coastguard Worker
4016*da0073e9SAndroid Build Coastguard Worker                # Need to fix up some default value strings.
4017*da0073e9SAndroid Build Coastguard Worker                # First case: modules. Default module `repr` contains the FS path of the module.
4018*da0073e9SAndroid Build Coastguard Worker                # Don't leak that
4019*da0073e9SAndroid Build Coastguard Worker                if isinstance(val, types.ModuleType):
4020*da0073e9SAndroid Build Coastguard Worker                    return f'<module {val.__name__}>'
4021*da0073e9SAndroid Build Coastguard Worker
4022*da0073e9SAndroid Build Coastguard Worker                # Second case: callables. Callables (such as lambdas) encode their address in
4023*da0073e9SAndroid Build Coastguard Worker                # their string repr. Don't do that
4024*da0073e9SAndroid Build Coastguard Worker                if callable(val):
4025*da0073e9SAndroid Build Coastguard Worker                    return f'<function {val.__name__}>'
4026*da0073e9SAndroid Build Coastguard Worker
4027*da0073e9SAndroid Build Coastguard Worker                return str(val)
4028*da0073e9SAndroid Build Coastguard Worker
4029*da0073e9SAndroid Build Coastguard Worker            if v.default is not inspect.Signature.empty:
4030*da0073e9SAndroid Build Coastguard Worker                default_val_str = default_val_str(v.default) if not isinstance(v.default, str) else f"'{v.default}'"
4031*da0073e9SAndroid Build Coastguard Worker                maybe_default = f' = {default_val_str}'
4032*da0073e9SAndroid Build Coastguard Worker            else:
4033*da0073e9SAndroid Build Coastguard Worker                maybe_default = ''
4034*da0073e9SAndroid Build Coastguard Worker            maybe_stars = ''
4035*da0073e9SAndroid Build Coastguard Worker            if v.kind == inspect.Parameter.VAR_POSITIONAL:
4036*da0073e9SAndroid Build Coastguard Worker                maybe_stars = '*'
4037*da0073e9SAndroid Build Coastguard Worker            elif v.kind == inspect.Parameter.VAR_KEYWORD:
4038*da0073e9SAndroid Build Coastguard Worker                maybe_stars = '**'
4039*da0073e9SAndroid Build Coastguard Worker            arg_strs.append(f'{maybe_stars}{k}{maybe_type_annotation}{maybe_default}')
4040*da0073e9SAndroid Build Coastguard Worker
4041*da0073e9SAndroid Build Coastguard Worker        return_annot = f' -> {self._annotation_type_to_stable_str(signature.return_annotation, sig_str)}'\
4042*da0073e9SAndroid Build Coastguard Worker            if signature.return_annotation is not inspect.Signature.empty else ''
4043*da0073e9SAndroid Build Coastguard Worker
4044*da0073e9SAndroid Build Coastguard Worker        return f'{fn_name}({", ".join(arg_strs)}){return_annot}'
4045*da0073e9SAndroid Build Coastguard Worker
4046*da0073e9SAndroid Build Coastguard Worker    def _annotation_type_to_stable_str(self, t, sig_str):
4047*da0073e9SAndroid Build Coastguard Worker        if t is inspect.Signature.empty:
4048*da0073e9SAndroid Build Coastguard Worker            return ''
4049*da0073e9SAndroid Build Coastguard Worker
4050*da0073e9SAndroid Build Coastguard Worker        # Forward ref
4051*da0073e9SAndroid Build Coastguard Worker        if isinstance(t, str):
4052*da0073e9SAndroid Build Coastguard Worker            return f"'{t}'"
4053*da0073e9SAndroid Build Coastguard Worker        if hasattr(typing, 'ForwardRef') and isinstance(t, typing.ForwardRef):
4054*da0073e9SAndroid Build Coastguard Worker            return t.__forward_arg__
4055*da0073e9SAndroid Build Coastguard Worker        if hasattr(typing, '_ForwardRef') and isinstance(t, typing._ForwardRef):
4056*da0073e9SAndroid Build Coastguard Worker            return t.__forward_arg__
4057*da0073e9SAndroid Build Coastguard Worker
4058*da0073e9SAndroid Build Coastguard Worker        trivial_mappings = {
4059*da0073e9SAndroid Build Coastguard Worker            str : 'str',
4060*da0073e9SAndroid Build Coastguard Worker            int : 'int',
4061*da0073e9SAndroid Build Coastguard Worker            float: 'float',
4062*da0073e9SAndroid Build Coastguard Worker            bool: 'bool',
4063*da0073e9SAndroid Build Coastguard Worker            torch.dtype: 'torch.dtype',
4064*da0073e9SAndroid Build Coastguard Worker            torch.Tensor: 'torch.Tensor',
4065*da0073e9SAndroid Build Coastguard Worker            torch.device: 'torch.device',
4066*da0073e9SAndroid Build Coastguard Worker            torch.memory_format: 'torch.memory_format',
4067*da0073e9SAndroid Build Coastguard Worker            slice: 'slice',
4068*da0073e9SAndroid Build Coastguard Worker            torch.nn.Module: 'torch.nn.modules.module.Module',
4069*da0073e9SAndroid Build Coastguard Worker            torch.fx.Graph : 'torch.fx.graph.Graph',
4070*da0073e9SAndroid Build Coastguard Worker            torch.fx.Node : 'torch.fx.node.Node',
4071*da0073e9SAndroid Build Coastguard Worker            torch.fx.Proxy : 'torch.fx.proxy.Proxy',
4072*da0073e9SAndroid Build Coastguard Worker            torch.fx.node.Target : 'torch.fx.node.Target',
4073*da0073e9SAndroid Build Coastguard Worker            torch.fx.node.Argument : 'torch.fx.node.Argument',
4074*da0073e9SAndroid Build Coastguard Worker            torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode',
4075*da0073e9SAndroid Build Coastguard Worker            torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule',
4076*da0073e9SAndroid Build Coastguard Worker            torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match',
4077*da0073e9SAndroid Build Coastguard Worker            Ellipsis : '...',
4078*da0073e9SAndroid Build Coastguard Worker            typing.Any: 'Any',
4079*da0073e9SAndroid Build Coastguard Worker            type(None): 'NoneType',
4080*da0073e9SAndroid Build Coastguard Worker            None: 'None',
4081*da0073e9SAndroid Build Coastguard Worker            typing.Iterator: 'Iterator',
4082*da0073e9SAndroid Build Coastguard Worker        }
4083*da0073e9SAndroid Build Coastguard Worker
4084*da0073e9SAndroid Build Coastguard Worker        mapping = trivial_mappings.get(t, None)
4085*da0073e9SAndroid Build Coastguard Worker        if mapping:
4086*da0073e9SAndroid Build Coastguard Worker            return mapping
4087*da0073e9SAndroid Build Coastguard Worker
4088*da0073e9SAndroid Build Coastguard Worker        # Handle types with contained types
4089*da0073e9SAndroid Build Coastguard Worker        contained = getattr(t, '__args__', None) or []
4090*da0073e9SAndroid Build Coastguard Worker
4091*da0073e9SAndroid Build Coastguard Worker        # Callables contain a bare List for arguments
4092*da0073e9SAndroid Build Coastguard Worker        contained = t if isinstance(t, list) else contained
4093*da0073e9SAndroid Build Coastguard Worker
4094*da0073e9SAndroid Build Coastguard Worker        # Python 3.8 puts type vars into __args__ for unbound types such as Dict
4095*da0073e9SAndroid Build Coastguard Worker        if all(isinstance(ct, typing.TypeVar) for ct in contained):
4096*da0073e9SAndroid Build Coastguard Worker            contained = []
4097*da0073e9SAndroid Build Coastguard Worker
4098*da0073e9SAndroid Build Coastguard Worker        contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str) for ct in contained]
4099*da0073e9SAndroid Build Coastguard Worker        contained_type_str = f'[{", ".join(contained_type_annots)}]' if len(contained_type_annots) > 0 else ''
4100*da0073e9SAndroid Build Coastguard Worker
4101*da0073e9SAndroid Build Coastguard Worker
4102*da0073e9SAndroid Build Coastguard Worker        origin = getattr(t, '__origin__', None)
4103*da0073e9SAndroid Build Coastguard Worker        if origin is None:
4104*da0073e9SAndroid Build Coastguard Worker            # Unbound types don't have `__origin__` in some Python versions, so fix that up here.
4105*da0073e9SAndroid Build Coastguard Worker            origin = t if t in {typing.Tuple, typing.Union, typing.Dict, typing.List, typing.Type, typing.Callable} else origin
4106*da0073e9SAndroid Build Coastguard Worker
4107*da0073e9SAndroid Build Coastguard Worker        if origin in {tuple, typing.Tuple}:
4108*da0073e9SAndroid Build Coastguard Worker            return f'Tuple{contained_type_str}'
4109*da0073e9SAndroid Build Coastguard Worker        if origin in {typing.Union}:
4110*da0073e9SAndroid Build Coastguard Worker            # Annoying hack to detect Optional
4111*da0073e9SAndroid Build Coastguard Worker            if len(contained) == 2 and (contained[0] is type(None)) ^ (contained[1] is type(None)):
4112*da0073e9SAndroid Build Coastguard Worker                not_none_param = contained[0] if contained[0] is not type(None) else contained[1]
4113*da0073e9SAndroid Build Coastguard Worker                return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]'
4114*da0073e9SAndroid Build Coastguard Worker            return f'Union{contained_type_str}'
4115*da0073e9SAndroid Build Coastguard Worker        if origin in {dict, typing.Dict}:
4116*da0073e9SAndroid Build Coastguard Worker            return f'Dict{contained_type_str}'
4117*da0073e9SAndroid Build Coastguard Worker        if origin in {list, typing.List}:
4118*da0073e9SAndroid Build Coastguard Worker            return f'List{contained_type_str}'
4119*da0073e9SAndroid Build Coastguard Worker        if origin in {type, typing.Type}:
4120*da0073e9SAndroid Build Coastguard Worker            return f'Type{contained_type_str}'
4121*da0073e9SAndroid Build Coastguard Worker        if isinstance(t, typing.Callable):
4122*da0073e9SAndroid Build Coastguard Worker            if len(contained) > 0 and contained[0] is not Ellipsis:
4123*da0073e9SAndroid Build Coastguard Worker                return f'Callable[[{", ".join(contained_type_annots[:-1])}], {contained_type_annots[-1]}]'
4124*da0073e9SAndroid Build Coastguard Worker            else:
4125*da0073e9SAndroid Build Coastguard Worker                return f'Callable{contained_type_str}'
4126*da0073e9SAndroid Build Coastguard Worker
4127*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError(f'Unrecognized type {t} used in BC-compatible type signature {sig_str}.'
4128*da0073e9SAndroid Build Coastguard Worker                           f'Please add support for this type and confirm with the '
4129*da0073e9SAndroid Build Coastguard Worker                           f'FX team that your signature change is valid.')
4130*da0073e9SAndroid Build Coastguard Worker
4131*da0073e9SAndroid Build Coastguard Worker
4132*da0073e9SAndroid Build Coastguard Worker    def test_function_back_compat(self):
4133*da0073e9SAndroid Build Coastguard Worker        """
4134*da0073e9SAndroid Build Coastguard Worker        Test backward compatibility for function signatures with
4135*da0073e9SAndroid Build Coastguard Worker        @compatibility(is_backward_compatible=True). Currently this checks for
4136*da0073e9SAndroid Build Coastguard Worker        exact signature matches, which may lead to false positives. If this
4137*da0073e9SAndroid Build Coastguard Worker        becomes too annoying, we can refine this check to actually parse out
4138*da0073e9SAndroid Build Coastguard Worker        the saved schema strings and check if the change is truly backward-
4139*da0073e9SAndroid Build Coastguard Worker        incompatible.
4140*da0073e9SAndroid Build Coastguard Worker        """
4141*da0073e9SAndroid Build Coastguard Worker        signature_strs = []
4142*da0073e9SAndroid Build Coastguard Worker
4143*da0073e9SAndroid Build Coastguard Worker        for obj in _BACK_COMPAT_OBJECTS:
4144*da0073e9SAndroid Build Coastguard Worker            if not isinstance(obj, type):
4145*da0073e9SAndroid Build Coastguard Worker                signature_strs.append(self._fn_to_stable_annotation_str(obj))
4146*da0073e9SAndroid Build Coastguard Worker
4147*da0073e9SAndroid Build Coastguard Worker        signature_strs.sort()
4148*da0073e9SAndroid Build Coastguard Worker
4149*da0073e9SAndroid Build Coastguard Worker        try:
4150*da0073e9SAndroid Build Coastguard Worker            self.assertExpected('\n'.join(signature_strs) + '\n', 'fx_backcompat_function_signatures')
4151*da0073e9SAndroid Build Coastguard Worker        except AssertionError as e:
4152*da0073e9SAndroid Build Coastguard Worker            msg = f"{e}\n****** ERROR ******\nAn FX function that has been marked " \
4153*da0073e9SAndroid Build Coastguard Worker                  f"as backwards-compatible has experienced a signature change. See the " \
4154*da0073e9SAndroid Build Coastguard Worker                  f"above exception context for more information. If this change was " \
4155*da0073e9SAndroid Build Coastguard Worker                  f"unintended, please revert it. If it was intended, check with the FX " \
4156*da0073e9SAndroid Build Coastguard Worker                  f"team to ensure that the proper deprecation protocols have been followed " \
4157*da0073e9SAndroid Build Coastguard Worker                  f"and subsequently --accept the change."
4158*da0073e9SAndroid Build Coastguard Worker            raise AssertionError(msg)  # noqa: B904
4159*da0073e9SAndroid Build Coastguard Worker
4160*da0073e9SAndroid Build Coastguard Worker    def test_class_member_back_compat(self):
4161*da0073e9SAndroid Build Coastguard Worker        """
4162*da0073e9SAndroid Build Coastguard Worker        Test backward compatibility for members of classes with
4163*da0073e9SAndroid Build Coastguard Worker        @compatibility(is_backward_compatible=True). Currently this checks for
4164*da0073e9SAndroid Build Coastguard Worker        exact matches on the publicly visible members of the class.
4165*da0073e9SAndroid Build Coastguard Worker        """
4166*da0073e9SAndroid Build Coastguard Worker        class_method_strs = []
4167*da0073e9SAndroid Build Coastguard Worker
4168*da0073e9SAndroid Build Coastguard Worker        for obj in _BACK_COMPAT_OBJECTS:
4169*da0073e9SAndroid Build Coastguard Worker            if isinstance(obj, type):
4170*da0073e9SAndroid Build Coastguard Worker                public_members = [name for name in obj.__dict__ if not name.startswith('_')]
4171*da0073e9SAndroid Build Coastguard Worker                class_method_strs.append(f'{torch.typename(obj)} {sorted(public_members)}')
4172*da0073e9SAndroid Build Coastguard Worker
4173*da0073e9SAndroid Build Coastguard Worker        class_method_strs.sort()
4174*da0073e9SAndroid Build Coastguard Worker
4175*da0073e9SAndroid Build Coastguard Worker        try:
4176*da0073e9SAndroid Build Coastguard Worker            self.assertExpected('\n'.join(class_method_strs), 'fx_backcompat_class_members')
4177*da0073e9SAndroid Build Coastguard Worker        except AssertionError as e:
4178*da0073e9SAndroid Build Coastguard Worker            msg = f"{e}\n****** ERROR ******\nAn FX class that has been marked " \
4179*da0073e9SAndroid Build Coastguard Worker                  f"as backwards-compatible has experienced change in its public members. See the " \
4180*da0073e9SAndroid Build Coastguard Worker                  f"above exception context for more information. If this change was " \
4181*da0073e9SAndroid Build Coastguard Worker                  f"unintended, please revert it. If it was intended, check with the FX " \
4182*da0073e9SAndroid Build Coastguard Worker                  f"team to ensure that the proper deprecation protocols have been followed " \
4183*da0073e9SAndroid Build Coastguard Worker                  f"and subsequently --accept the change."
4184*da0073e9SAndroid Build Coastguard Worker            raise AssertionError(msg) from e
4185*da0073e9SAndroid Build Coastguard Worker
4186*da0073e9SAndroid Build Coastguard Worker    def test_public_api_surface(self):
4187*da0073e9SAndroid Build Coastguard Worker        non_back_compat_objects = {}
4188*da0073e9SAndroid Build Coastguard Worker
4189*da0073e9SAndroid Build Coastguard Worker        def check_symbols_have_bc_designation(m, seen):
4190*da0073e9SAndroid Build Coastguard Worker            if not m.__name__.startswith('torch.fx'):
4191*da0073e9SAndroid Build Coastguard Worker                return
4192*da0073e9SAndroid Build Coastguard Worker            if m.__name__.startswith('torch.fx.experimental'):
4193*da0073e9SAndroid Build Coastguard Worker                return
4194*da0073e9SAndroid Build Coastguard Worker            # It's really common for inner functions to point to random modules
4195*da0073e9SAndroid Build Coastguard Worker            # - make sure we don't recurse into modules we've already checked.
4196*da0073e9SAndroid Build Coastguard Worker            seen.add(m.__name__)
4197*da0073e9SAndroid Build Coastguard Worker            for k, v in m.__dict__.items():
4198*da0073e9SAndroid Build Coastguard Worker                if hasattr(v, '__name__') and v.__name__ in seen:
4199*da0073e9SAndroid Build Coastguard Worker                    continue
4200*da0073e9SAndroid Build Coastguard Worker                if v is m:
4201*da0073e9SAndroid Build Coastguard Worker                    continue
4202*da0073e9SAndroid Build Coastguard Worker                if k.startswith('_'):
4203*da0073e9SAndroid Build Coastguard Worker                    continue
4204*da0073e9SAndroid Build Coastguard Worker                if isinstance(v, types.ModuleType):
4205*da0073e9SAndroid Build Coastguard Worker                    check_symbols_have_bc_designation(v, seen)
4206*da0073e9SAndroid Build Coastguard Worker                elif isinstance(v, (type, types.FunctionType)):
4207*da0073e9SAndroid Build Coastguard Worker                    if v not in _MARKED_WITH_COMPATIBILITY:
4208*da0073e9SAndroid Build Coastguard Worker                        non_back_compat_objects.setdefault(v)
4209*da0073e9SAndroid Build Coastguard Worker
4210*da0073e9SAndroid Build Coastguard Worker        check_symbols_have_bc_designation(torch.fx, set())
4211*da0073e9SAndroid Build Coastguard Worker        check_symbols_have_bc_designation(torch.fx.passes, set())
4212*da0073e9SAndroid Build Coastguard Worker
4213*da0073e9SAndroid Build Coastguard Worker        non_back_compat_strs = [torch.typename(obj) for obj in non_back_compat_objects.keys()]
4214*da0073e9SAndroid Build Coastguard Worker        # Only want objects in torch.fx
4215*da0073e9SAndroid Build Coastguard Worker        non_back_compat_strs = [
4216*da0073e9SAndroid Build Coastguard Worker            s for s in non_back_compat_strs if s.startswith('torch.fx') and not s.startswith('torch.fx.experimental')]
4217*da0073e9SAndroid Build Coastguard Worker        # Only want objects in public namespaces
4218*da0073e9SAndroid Build Coastguard Worker        non_back_compat_strs = [
4219*da0073e9SAndroid Build Coastguard Worker            s for s in non_back_compat_strs if all(not atom.startswith('_') for atom in s.split('.'))]
4220*da0073e9SAndroid Build Coastguard Worker        non_back_compat_strs.sort()
4221*da0073e9SAndroid Build Coastguard Worker
4222*da0073e9SAndroid Build Coastguard Worker        if len(non_back_compat_strs) != 0:
4223*da0073e9SAndroid Build Coastguard Worker            raise AssertionError(f"Public FX API(s) {non_back_compat_strs} introduced but not given a "
4224*da0073e9SAndroid Build Coastguard Worker                                 f"backwards-compatibility classification! Please decorate these "
4225*da0073e9SAndroid Build Coastguard Worker                                 f"API(s) with `@torch.fx._compatibility.compatibility` to specify "
4226*da0073e9SAndroid Build Coastguard Worker                                 f"BC guarantees.")
4227*da0073e9SAndroid Build Coastguard Worker
4228*da0073e9SAndroid Build Coastguard Worker    def test_adding_side_effect_function(self):
4229*da0073e9SAndroid Build Coastguard Worker        class TestModule(torch.nn.Module):
4230*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
4231*da0073e9SAndroid Build Coastguard Worker                side_effect_func(x)
4232*da0073e9SAndroid Build Coastguard Worker                return x
4233*da0073e9SAndroid Build Coastguard Worker
4234*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.symbolic_trace(TestModule())
4235*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(gm.graph.nodes), 3)
4236*da0073e9SAndroid Build Coastguard Worker        gm.graph.eliminate_dead_code()
4237*da0073e9SAndroid Build Coastguard Worker        gm.recompile()
4238*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(gm.graph.nodes), 3)
4239*da0073e9SAndroid Build Coastguard Worker        found = False
4240*da0073e9SAndroid Build Coastguard Worker        for node in gm.graph.nodes:
4241*da0073e9SAndroid Build Coastguard Worker            if node.op == 'call_function' and node.target == side_effect_func:
4242*da0073e9SAndroid Build Coastguard Worker                found = True
4243*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(found)
4244*da0073e9SAndroid Build Coastguard Worker
4245*da0073e9SAndroid Build Coastguard Worker    def test_preserve_unused_attr_after_unpickle(self):
4246*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.symbolic_trace(Add())
4247*da0073e9SAndroid Build Coastguard Worker        gm.add_submodule("foo", Add())
4248*da0073e9SAndroid Build Coastguard Worker        gm.dummy_buffer = torch.nn.Buffer(torch.empty(1))
4249*da0073e9SAndroid Build Coastguard Worker        gm.register_parameter("dummy_parameter", torch.nn.Parameter(torch.empty(1)))
4250*da0073e9SAndroid Build Coastguard Worker        b = io.BytesIO()
4251*da0073e9SAndroid Build Coastguard Worker        torch.save(gm, b)
4252*da0073e9SAndroid Build Coastguard Worker        b.seek(0)
4253*da0073e9SAndroid Build Coastguard Worker        # weights_only=False as this loads a GraphModule
4254*da0073e9SAndroid Build Coastguard Worker        # GLOBAL torch.fx.graph_module.reduce_graph_module was not an allowed global by default
4255*da0073e9SAndroid Build Coastguard Worker        reload_gm = torch.load(b, weights_only=False)
4256*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(hasattr(reload_gm, "foo"))
4257*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(hasattr(reload_gm, "dummy_buffer"))
4258*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(hasattr(reload_gm, "dummy_parameter"))
4259*da0073e9SAndroid Build Coastguard Worker
4260*da0073e9SAndroid Build Coastguard Worker# This is failing on Python 3.12 : https://github.com/pytorch/pytorch/issues/119454
4261*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(
4262*da0073e9SAndroid Build Coastguard Worker    sys.version_info >= (3, 12), "Failing on python 3.12+"
4263*da0073e9SAndroid Build Coastguard Worker)
4264*da0073e9SAndroid Build Coastguard Workerclass TestFunctionalTracing(JitTestCase):
4265*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
4266*da0073e9SAndroid Build Coastguard Worker        super().setUp()
4267*da0073e9SAndroid Build Coastguard Worker        # Checking for mutable operations whil tracing is feature flagged
4268*da0073e9SAndroid Build Coastguard Worker        # Enable it in testing but not by default
4269*da0073e9SAndroid Build Coastguard Worker        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
4270*da0073e9SAndroid Build Coastguard Worker        torch.fx.proxy.TracerBase.check_mutable_operations = True
4271*da0073e9SAndroid Build Coastguard Worker
4272*da0073e9SAndroid Build Coastguard Worker    def tearDown(self):
4273*da0073e9SAndroid Build Coastguard Worker        super().tearDown()
4274*da0073e9SAndroid Build Coastguard Worker        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag
4275*da0073e9SAndroid Build Coastguard Worker
4276*da0073e9SAndroid Build Coastguard Worker    IGNORE_FUNCS = ("has_torch_function", "has_torch_function_unary",
4277*da0073e9SAndroid Build Coastguard Worker                    "has_torch_function_variadic", "handle_torch_function",
4278*da0073e9SAndroid Build Coastguard Worker                    "boolean_dispatch")
4279*da0073e9SAndroid Build Coastguard Worker    TO_PATCH = {"has_torch_function": None,
4280*da0073e9SAndroid Build Coastguard Worker                "has_torch_function_unary": None,
4281*da0073e9SAndroid Build Coastguard Worker                "has_torch_function_variadic": None}
4282*da0073e9SAndroid Build Coastguard Worker
4283*da0073e9SAndroid Build Coastguard Worker    BUILT_IN_FUNC = (AssertionError, "")
4284*da0073e9SAndroid Build Coastguard Worker    PROXY_ITERABLE = (TypeError, r"argument of type 'Proxy' is not iterable")
4285*da0073e9SAndroid Build Coastguard Worker    PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated")
4286*da0073e9SAndroid Build Coastguard Worker    LEN_ERROR = (RuntimeError, r"'len' is not supported in symbolic tracing by default")
4287*da0073e9SAndroid Build Coastguard Worker    ARG_TYPE_MISMATCH = (TypeError, r", not Proxy$")
4288*da0073e9SAndroid Build Coastguard Worker    CONTROL_FLOW = (TraceError, r"symbolically traced variables cannot be used as inputs to control flow")
4289*da0073e9SAndroid Build Coastguard Worker    INTERPOLATE_ARGS_CONFLICT = (ValueError, r"only one of size or scale_factor should be defined")
4290*da0073e9SAndroid Build Coastguard Worker    MUTABLE = (RuntimeError, r"Tried to trace mutable operation")
4291*da0073e9SAndroid Build Coastguard Worker
4292*da0073e9SAndroid Build Coastguard Worker    UNTRACEABLE_FUNCTIONALS = {
4293*da0073e9SAndroid Build Coastguard Worker        "adaptive_avg_pool1d": BUILT_IN_FUNC,
4294*da0073e9SAndroid Build Coastguard Worker        "avg_pool1d": BUILT_IN_FUNC,
4295*da0073e9SAndroid Build Coastguard Worker        "avg_pool2d": BUILT_IN_FUNC,
4296*da0073e9SAndroid Build Coastguard Worker        "avg_pool3d": BUILT_IN_FUNC,
4297*da0073e9SAndroid Build Coastguard Worker        "bilinear": BUILT_IN_FUNC,
4298*da0073e9SAndroid Build Coastguard Worker        "celu_": BUILT_IN_FUNC,
4299*da0073e9SAndroid Build Coastguard Worker        "channel_shuffle": BUILT_IN_FUNC,
4300*da0073e9SAndroid Build Coastguard Worker        "native_channel_shuffle": BUILT_IN_FUNC,
4301*da0073e9SAndroid Build Coastguard Worker        "conv1d": BUILT_IN_FUNC,
4302*da0073e9SAndroid Build Coastguard Worker        "conv2d": BUILT_IN_FUNC,
4303*da0073e9SAndroid Build Coastguard Worker        "conv3d": BUILT_IN_FUNC,
4304*da0073e9SAndroid Build Coastguard Worker        "conv_tbc": BUILT_IN_FUNC,
4305*da0073e9SAndroid Build Coastguard Worker        "conv_transpose1d": BUILT_IN_FUNC,
4306*da0073e9SAndroid Build Coastguard Worker        "conv_transpose2d": BUILT_IN_FUNC,
4307*da0073e9SAndroid Build Coastguard Worker        "conv_transpose3d": BUILT_IN_FUNC,
4308*da0073e9SAndroid Build Coastguard Worker        "cosine_similarity": BUILT_IN_FUNC,
4309*da0073e9SAndroid Build Coastguard Worker        "elu_": BUILT_IN_FUNC,
4310*da0073e9SAndroid Build Coastguard Worker        "gelu": BUILT_IN_FUNC,
4311*da0073e9SAndroid Build Coastguard Worker        "hardshrink": BUILT_IN_FUNC,
4312*da0073e9SAndroid Build Coastguard Worker        "hardtanh_": BUILT_IN_FUNC,
4313*da0073e9SAndroid Build Coastguard Worker        "leaky_relu_": BUILT_IN_FUNC,
4314*da0073e9SAndroid Build Coastguard Worker        "linear": BUILT_IN_FUNC,
4315*da0073e9SAndroid Build Coastguard Worker        "logsigmoid": BUILT_IN_FUNC,
4316*da0073e9SAndroid Build Coastguard Worker        "one_hot": BUILT_IN_FUNC,
4317*da0073e9SAndroid Build Coastguard Worker        "pad": ARG_TYPE_MISMATCH,
4318*da0073e9SAndroid Build Coastguard Worker        "pairwise_distance": BUILT_IN_FUNC,
4319*da0073e9SAndroid Build Coastguard Worker        "pdist": BUILT_IN_FUNC,
4320*da0073e9SAndroid Build Coastguard Worker        "pixel_shuffle": BUILT_IN_FUNC,
4321*da0073e9SAndroid Build Coastguard Worker        "pixel_unshuffle": BUILT_IN_FUNC,
4322*da0073e9SAndroid Build Coastguard Worker        "prelu": BUILT_IN_FUNC,
4323*da0073e9SAndroid Build Coastguard Worker        "relu_": BUILT_IN_FUNC,
4324*da0073e9SAndroid Build Coastguard Worker        "rrelu_": BUILT_IN_FUNC,
4325*da0073e9SAndroid Build Coastguard Worker        "selu_": BUILT_IN_FUNC,
4326*da0073e9SAndroid Build Coastguard Worker        "scaled_dot_product_attention": BUILT_IN_FUNC,
4327*da0073e9SAndroid Build Coastguard Worker        "softplus": BUILT_IN_FUNC,
4328*da0073e9SAndroid Build Coastguard Worker        "softshrink": BUILT_IN_FUNC,
4329*da0073e9SAndroid Build Coastguard Worker        "threshold_": BUILT_IN_FUNC,
4330*da0073e9SAndroid Build Coastguard Worker
4331*da0073e9SAndroid Build Coastguard Worker        "adaptive_avg_pool2d": LEN_ERROR,
4332*da0073e9SAndroid Build Coastguard Worker        "adaptive_avg_pool3d": LEN_ERROR,
4333*da0073e9SAndroid Build Coastguard Worker        "adaptive_max_pool2d_with_indices": LEN_ERROR,
4334*da0073e9SAndroid Build Coastguard Worker        "adaptive_max_pool3d_with_indices": LEN_ERROR,
4335*da0073e9SAndroid Build Coastguard Worker        "instance_norm": CONTROL_FLOW,
4336*da0073e9SAndroid Build Coastguard Worker
4337*da0073e9SAndroid Build Coastguard Worker        "adaptive_max_pool1d": PROXY_ITERABLE,
4338*da0073e9SAndroid Build Coastguard Worker        "adaptive_max_pool2d": PROXY_ITERABLE,
4339*da0073e9SAndroid Build Coastguard Worker        "adaptive_max_pool3d": PROXY_ITERABLE,
4340*da0073e9SAndroid Build Coastguard Worker        "fractional_max_pool2d": PROXY_ITERABLE,
4341*da0073e9SAndroid Build Coastguard Worker        "fractional_max_pool3d": PROXY_ITERABLE,
4342*da0073e9SAndroid Build Coastguard Worker        "max_pool1d": PROXY_ITERABLE,
4343*da0073e9SAndroid Build Coastguard Worker        "max_pool2d": PROXY_ITERABLE,
4344*da0073e9SAndroid Build Coastguard Worker        "max_pool3d": PROXY_ITERABLE,
4345*da0073e9SAndroid Build Coastguard Worker
4346*da0073e9SAndroid Build Coastguard Worker        "lp_pool2d": PROXY_ITERATED,
4347*da0073e9SAndroid Build Coastguard Worker        "lp_pool3d": PROXY_ITERATED,
4348*da0073e9SAndroid Build Coastguard Worker        "max_unpool1d": PROXY_ITERATED,
4349*da0073e9SAndroid Build Coastguard Worker        "max_unpool2d": PROXY_ITERATED,
4350*da0073e9SAndroid Build Coastguard Worker        "max_unpool3d": PROXY_ITERATED,
4351*da0073e9SAndroid Build Coastguard Worker        "fold": PROXY_ITERATED,
4352*da0073e9SAndroid Build Coastguard Worker        "unfold": PROXY_ITERATED,
4353*da0073e9SAndroid Build Coastguard Worker
4354*da0073e9SAndroid Build Coastguard Worker        "adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH,
4355*da0073e9SAndroid Build Coastguard Worker        "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH,
4356*da0073e9SAndroid Build Coastguard Worker        "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH,
4357*da0073e9SAndroid Build Coastguard Worker        "layer_norm": ARG_TYPE_MISMATCH,
4358*da0073e9SAndroid Build Coastguard Worker        "rms_norm": ARG_TYPE_MISMATCH,
4359*da0073e9SAndroid Build Coastguard Worker        "lp_pool1d": ARG_TYPE_MISMATCH,
4360*da0073e9SAndroid Build Coastguard Worker
4361*da0073e9SAndroid Build Coastguard Worker        "affine_grid": CONTROL_FLOW,
4362*da0073e9SAndroid Build Coastguard Worker        "alpha_dropout": CONTROL_FLOW,
4363*da0073e9SAndroid Build Coastguard Worker        "batch_norm": CONTROL_FLOW,
4364*da0073e9SAndroid Build Coastguard Worker        "binary_cross_entropy": CONTROL_FLOW,
4365*da0073e9SAndroid Build Coastguard Worker        "binary_cross_entropy_with_logits": CONTROL_FLOW,
4366*da0073e9SAndroid Build Coastguard Worker        "celu": CONTROL_FLOW,
4367*da0073e9SAndroid Build Coastguard Worker        "cosine_embedding_loss": CONTROL_FLOW,
4368*da0073e9SAndroid Build Coastguard Worker        "cross_entropy": CONTROL_FLOW,
4369*da0073e9SAndroid Build Coastguard Worker        "ctc_loss": CONTROL_FLOW,
4370*da0073e9SAndroid Build Coastguard Worker        "dropout": CONTROL_FLOW,
4371*da0073e9SAndroid Build Coastguard Worker        "dropout1d": CONTROL_FLOW,
4372*da0073e9SAndroid Build Coastguard Worker        "dropout2d": CONTROL_FLOW,
4373*da0073e9SAndroid Build Coastguard Worker        "dropout3d": CONTROL_FLOW,
4374*da0073e9SAndroid Build Coastguard Worker        "elu": CONTROL_FLOW,
4375*da0073e9SAndroid Build Coastguard Worker        "embedding": CONTROL_FLOW,
4376*da0073e9SAndroid Build Coastguard Worker        "embedding_bag": CONTROL_FLOW,
4377*da0073e9SAndroid Build Coastguard Worker        "feature_alpha_dropout": CONTROL_FLOW,
4378*da0073e9SAndroid Build Coastguard Worker        "gaussian_nll_loss": CONTROL_FLOW,
4379*da0073e9SAndroid Build Coastguard Worker        "glu": CONTROL_FLOW,
4380*da0073e9SAndroid Build Coastguard Worker        "grid_sample": CONTROL_FLOW,
4381*da0073e9SAndroid Build Coastguard Worker        "group_norm": CONTROL_FLOW,
4382*da0073e9SAndroid Build Coastguard Worker        "gumbel_softmax": CONTROL_FLOW,
4383*da0073e9SAndroid Build Coastguard Worker        "hardsigmoid": CONTROL_FLOW,
4384*da0073e9SAndroid Build Coastguard Worker        "hardswish": CONTROL_FLOW,
4385*da0073e9SAndroid Build Coastguard Worker        "hardtanh": CONTROL_FLOW,
4386*da0073e9SAndroid Build Coastguard Worker        "hinge_embedding_loss": CONTROL_FLOW,
4387*da0073e9SAndroid Build Coastguard Worker        "huber_loss": CONTROL_FLOW,
4388*da0073e9SAndroid Build Coastguard Worker        "interpolate": CONTROL_FLOW,
4389*da0073e9SAndroid Build Coastguard Worker        "kl_div": CONTROL_FLOW,
4390*da0073e9SAndroid Build Coastguard Worker        "l1_loss": CONTROL_FLOW,
4391*da0073e9SAndroid Build Coastguard Worker        "leaky_relu": CONTROL_FLOW,
4392*da0073e9SAndroid Build Coastguard Worker        "local_response_norm": CONTROL_FLOW,
4393*da0073e9SAndroid Build Coastguard Worker        "margin_ranking_loss": CONTROL_FLOW,
4394*da0073e9SAndroid Build Coastguard Worker        "max_pool1d_with_indices": ARG_TYPE_MISMATCH,
4395*da0073e9SAndroid Build Coastguard Worker        "max_pool2d_with_indices": ARG_TYPE_MISMATCH,
4396*da0073e9SAndroid Build Coastguard Worker        "max_pool3d_with_indices": ARG_TYPE_MISMATCH,
4397*da0073e9SAndroid Build Coastguard Worker        "mse_loss": CONTROL_FLOW,
4398*da0073e9SAndroid Build Coastguard Worker        "multi_head_attention_forward": CONTROL_FLOW,
4399*da0073e9SAndroid Build Coastguard Worker        "multi_margin_loss": CONTROL_FLOW,
4400*da0073e9SAndroid Build Coastguard Worker        "multilabel_margin_loss": CONTROL_FLOW,
4401*da0073e9SAndroid Build Coastguard Worker        "multilabel_soft_margin_loss": CONTROL_FLOW,
4402*da0073e9SAndroid Build Coastguard Worker        "nll_loss": CONTROL_FLOW,
4403*da0073e9SAndroid Build Coastguard Worker        "poisson_nll_loss": CONTROL_FLOW,
4404*da0073e9SAndroid Build Coastguard Worker        "relu": CONTROL_FLOW,
4405*da0073e9SAndroid Build Coastguard Worker        "relu6": CONTROL_FLOW,
4406*da0073e9SAndroid Build Coastguard Worker        "rrelu": CONTROL_FLOW,
4407*da0073e9SAndroid Build Coastguard Worker        "selu": CONTROL_FLOW,
4408*da0073e9SAndroid Build Coastguard Worker        "silu": CONTROL_FLOW,
4409*da0073e9SAndroid Build Coastguard Worker        "mish": CONTROL_FLOW,
4410*da0073e9SAndroid Build Coastguard Worker        "smooth_l1_loss": CONTROL_FLOW,
4411*da0073e9SAndroid Build Coastguard Worker        "soft_margin_loss": CONTROL_FLOW,
4412*da0073e9SAndroid Build Coastguard Worker        "threshold": CONTROL_FLOW,
4413*da0073e9SAndroid Build Coastguard Worker        "triplet_margin_loss": CONTROL_FLOW,
4414*da0073e9SAndroid Build Coastguard Worker        "triplet_margin_with_distance_loss": CONTROL_FLOW,
4415*da0073e9SAndroid Build Coastguard Worker        "upsample": CONTROL_FLOW,
4416*da0073e9SAndroid Build Coastguard Worker
4417*da0073e9SAndroid Build Coastguard Worker        "upsample_bilinear": INTERPOLATE_ARGS_CONFLICT,
4418*da0073e9SAndroid Build Coastguard Worker        "upsample_nearest": INTERPOLATE_ARGS_CONFLICT,
4419*da0073e9SAndroid Build Coastguard Worker    }
4420*da0073e9SAndroid Build Coastguard Worker
4421*da0073e9SAndroid Build Coastguard Worker    # List of nn.functionals with Tensor inputs but not with type annotation
4422*da0073e9SAndroid Build Coastguard Worker    FUNCTIONALS_WITHOUT_ANNOTATION = (
4423*da0073e9SAndroid Build Coastguard Worker        "adaptive_max_pool1d",
4424*da0073e9SAndroid Build Coastguard Worker        "adaptive_max_pool2d",
4425*da0073e9SAndroid Build Coastguard Worker        "adaptive_max_pool3d",
4426*da0073e9SAndroid Build Coastguard Worker        "fractional_max_pool2d",
4427*da0073e9SAndroid Build Coastguard Worker        "fractional_max_pool3d",
4428*da0073e9SAndroid Build Coastguard Worker        "max_pool1d",
4429*da0073e9SAndroid Build Coastguard Worker        "max_pool2d",
4430*da0073e9SAndroid Build Coastguard Worker        "max_pool3d",
4431*da0073e9SAndroid Build Coastguard Worker        "gaussian_nll_loss",
4432*da0073e9SAndroid Build Coastguard Worker        "upsample",
4433*da0073e9SAndroid Build Coastguard Worker        "upsample_bilinear",
4434*da0073e9SAndroid Build Coastguard Worker        "upsample_nearest",
4435*da0073e9SAndroid Build Coastguard Worker    )
4436*da0073e9SAndroid Build Coastguard Worker
4437*da0073e9SAndroid Build Coastguard Worker    # Inconsistent behavior between Python 3.8 and other Python versions:
4438*da0073e9SAndroid Build Coastguard Worker    # - Python 3.8+: Re-raise internal exception like `PROXY_ITERATED`
4439*da0073e9SAndroid Build Coastguard Worker    # - Other Python: Raise `argument of type 'Proxy' is not iterable` due to the same
4440*da0073e9SAndroid Build Coastguard Worker    #                 internal exception above
4441*da0073e9SAndroid Build Coastguard Worker    # Use the following map to override the expected exception for Python 3.8
4442*da0073e9SAndroid Build Coastguard Worker    UNTRACEABLE_FUNCTIONALS_PY38 = {
4443*da0073e9SAndroid Build Coastguard Worker        "adaptive_max_pool1d": PROXY_ITERATED,
4444*da0073e9SAndroid Build Coastguard Worker        "adaptive_max_pool2d": PROXY_ITERATED,
4445*da0073e9SAndroid Build Coastguard Worker        "adaptive_max_pool3d": PROXY_ITERATED,
4446*da0073e9SAndroid Build Coastguard Worker        "fractional_max_pool2d": PROXY_ITERATED,
4447*da0073e9SAndroid Build Coastguard Worker        "fractional_max_pool3d": PROXY_ITERATED,
4448*da0073e9SAndroid Build Coastguard Worker        "max_pool1d": PROXY_ITERATED,
4449*da0073e9SAndroid Build Coastguard Worker        "max_pool2d": PROXY_ITERATED,
4450*da0073e9SAndroid Build Coastguard Worker        "max_pool3d": PROXY_ITERATED,
4451*da0073e9SAndroid Build Coastguard Worker
4452*da0073e9SAndroid Build Coastguard Worker        "group_norm": CONTROL_FLOW
4453*da0073e9SAndroid Build Coastguard Worker    }
4454*da0073e9SAndroid Build Coastguard Worker
4455*da0073e9SAndroid Build Coastguard Worker    @classmethod
4456*da0073e9SAndroid Build Coastguard Worker    def _get_functional(cls):
4457*da0073e9SAndroid Build Coastguard Worker        functional_list = []
4458*da0073e9SAndroid Build Coastguard Worker        for f in dir(torch.nn.functional):
4459*da0073e9SAndroid Build Coastguard Worker            if not f.islower():
4460*da0073e9SAndroid Build Coastguard Worker                continue
4461*da0073e9SAndroid Build Coastguard Worker            # Ignore internal functions
4462*da0073e9SAndroid Build Coastguard Worker            if f.startswith('_'):
4463*da0073e9SAndroid Build Coastguard Worker                continue
4464*da0073e9SAndroid Build Coastguard Worker            # Ignore supporting functions
4465*da0073e9SAndroid Build Coastguard Worker            if f in cls.IGNORE_FUNCS:
4466*da0073e9SAndroid Build Coastguard Worker                continue
4467*da0073e9SAndroid Build Coastguard Worker            fn = getattr(torch.nn.functional, f)
4468*da0073e9SAndroid Build Coastguard Worker            # Ignore non-callable object like modules
4469*da0073e9SAndroid Build Coastguard Worker            if not isinstance(fn, Callable):
4470*da0073e9SAndroid Build Coastguard Worker                continue
4471*da0073e9SAndroid Build Coastguard Worker            if f not in cls.FUNCTIONALS_WITHOUT_ANNOTATION:
4472*da0073e9SAndroid Build Coastguard Worker                try:
4473*da0073e9SAndroid Build Coastguard Worker                    sig = inspect.signature(fn)
4474*da0073e9SAndroid Build Coastguard Worker                    has_tensor_arg = False
4475*da0073e9SAndroid Build Coastguard Worker                    for param in sig.parameters.values():
4476*da0073e9SAndroid Build Coastguard Worker                        if isinstance(param.annotation, type) and issubclass(param.annotation, torch.Tensor):
4477*da0073e9SAndroid Build Coastguard Worker                            has_tensor_arg = True
4478*da0073e9SAndroid Build Coastguard Worker                    if not has_tensor_arg:
4479*da0073e9SAndroid Build Coastguard Worker                        continue
4480*da0073e9SAndroid Build Coastguard Worker                # No signature or Object is not supported
4481*da0073e9SAndroid Build Coastguard Worker                except ValueError:
4482*da0073e9SAndroid Build Coastguard Worker                    pass
4483*da0073e9SAndroid Build Coastguard Worker            functional_list.append((f, fn))
4484*da0073e9SAndroid Build Coastguard Worker        return functional_list
4485*da0073e9SAndroid Build Coastguard Worker
4486*da0073e9SAndroid Build Coastguard Worker    @classmethod
4487*da0073e9SAndroid Build Coastguard Worker    def generate_test_func(cls, func_name, fn):
4488*da0073e9SAndroid Build Coastguard Worker
4489*da0073e9SAndroid Build Coastguard Worker        def functional_test(self):
4490*da0073e9SAndroid Build Coastguard Worker            if func_name in self.UNTRACEABLE_FUNCTIONALS_PY38 and \
4491*da0073e9SAndroid Build Coastguard Worker                    sys.version_info >= (3, 8) and sys.version_info < (3, 12):
4492*da0073e9SAndroid Build Coastguard Worker                exc, err = self.UNTRACEABLE_FUNCTIONALS_PY38[func_name]
4493*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(exc, err):
4494*da0073e9SAndroid Build Coastguard Worker                    symbolic_trace(fn)
4495*da0073e9SAndroid Build Coastguard Worker            elif func_name in self.UNTRACEABLE_FUNCTIONALS:
4496*da0073e9SAndroid Build Coastguard Worker                exc, err = self.UNTRACEABLE_FUNCTIONALS[func_name]
4497*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(exc, err):
4498*da0073e9SAndroid Build Coastguard Worker                    symbolic_trace(fn)
4499*da0073e9SAndroid Build Coastguard Worker            else:
4500*da0073e9SAndroid Build Coastguard Worker                symbolic_trace(fn)
4501*da0073e9SAndroid Build Coastguard Worker        return functional_test
4502*da0073e9SAndroid Build Coastguard Worker
4503*da0073e9SAndroid Build Coastguard Worker    @classmethod
4504*da0073e9SAndroid Build Coastguard Worker    def generate_tests(cls):
4505*da0073e9SAndroid Build Coastguard Worker        functional_list = cls._get_functional()
4506*da0073e9SAndroid Build Coastguard Worker        for func_name, fn in functional_list:
4507*da0073e9SAndroid Build Coastguard Worker            test_name = "test_nn_functional_" + func_name
4508*da0073e9SAndroid Build Coastguard Worker            functional_test = cls.generate_test_func(func_name, fn)
4509*da0073e9SAndroid Build Coastguard Worker            setattr(cls, test_name, functional_test)
4510*da0073e9SAndroid Build Coastguard Worker
4511*da0073e9SAndroid Build Coastguard Worker    @classmethod
4512*da0073e9SAndroid Build Coastguard Worker    def setUpClass(cls):
4513*da0073e9SAndroid Build Coastguard Worker
4514*da0073e9SAndroid Build Coastguard Worker        def no(*args, **kwargs):
4515*da0073e9SAndroid Build Coastguard Worker            return False
4516*da0073e9SAndroid Build Coastguard Worker
4517*da0073e9SAndroid Build Coastguard Worker        for name in cls.TO_PATCH.keys():
4518*da0073e9SAndroid Build Coastguard Worker            cls.TO_PATCH[name] = getattr(torch.nn.functional, name)
4519*da0073e9SAndroid Build Coastguard Worker            setattr(torch.nn.functional, name, no)
4520*da0073e9SAndroid Build Coastguard Worker
4521*da0073e9SAndroid Build Coastguard Worker    @classmethod
4522*da0073e9SAndroid Build Coastguard Worker    def tearDownClass(cls):
4523*da0073e9SAndroid Build Coastguard Worker        for name in cls.TO_PATCH.keys():
4524*da0073e9SAndroid Build Coastguard Worker            setattr(torch.nn.functional, name, cls.TO_PATCH[name])
4525*da0073e9SAndroid Build Coastguard Worker
4526*da0073e9SAndroid Build Coastguard WorkerTestFunctionalTracing.generate_tests()
4527*da0073e9SAndroid Build Coastguard Worker
4528*da0073e9SAndroid Build Coastguard Worker
4529*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestOperatorSignatures, globals())
4530*da0073e9SAndroid Build Coastguard Worker
4531*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo("too slow")
4532*da0073e9SAndroid Build Coastguard Worker@skipIfNoTorchVision
4533*da0073e9SAndroid Build Coastguard Workerclass TestVisionTracing(JitTestCase):
4534*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
4535*da0073e9SAndroid Build Coastguard Worker        # Checking for mutable operations while tracing is feature flagged
4536*da0073e9SAndroid Build Coastguard Worker        # Enable it in testing but not by default
4537*da0073e9SAndroid Build Coastguard Worker        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
4538*da0073e9SAndroid Build Coastguard Worker        torch.fx.proxy.TracerBase.check_mutable_operations = True
4539*da0073e9SAndroid Build Coastguard Worker
4540*da0073e9SAndroid Build Coastguard Worker    def tearDown(self):
4541*da0073e9SAndroid Build Coastguard Worker        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag
4542*da0073e9SAndroid Build Coastguard Worker
4543*da0073e9SAndroid Build Coastguard Worker    PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated")
4544*da0073e9SAndroid Build Coastguard Worker    INCONSISTENT_TYPE = (
4545*da0073e9SAndroid Build Coastguard Worker        RuntimeError,
4546*da0073e9SAndroid Build Coastguard Worker        r"Return value was annotated as having type __torch__.torchvision.models[.\w]+ but is actually of type Tensor"
4547*da0073e9SAndroid Build Coastguard Worker    )
4548*da0073e9SAndroid Build Coastguard Worker
4549*da0073e9SAndroid Build Coastguard Worker    UNTRACEABLE_MODELS = {
4550*da0073e9SAndroid Build Coastguard Worker        "fasterrcnn_resnet50_fpn": PROXY_ITERATED,
4551*da0073e9SAndroid Build Coastguard Worker        "fasterrcnn_resnet50_fpn_v2": PROXY_ITERATED,
4552*da0073e9SAndroid Build Coastguard Worker        "fasterrcnn_mobilenet_v3_large_320_fpn": PROXY_ITERATED,
4553*da0073e9SAndroid Build Coastguard Worker        "fasterrcnn_mobilenet_v3_large_fpn": PROXY_ITERATED,
4554*da0073e9SAndroid Build Coastguard Worker        "maskrcnn_resnet50_fpn": PROXY_ITERATED,
4555*da0073e9SAndroid Build Coastguard Worker        "maskrcnn_resnet50_fpn_v2": PROXY_ITERATED,
4556*da0073e9SAndroid Build Coastguard Worker        "keypointrcnn_resnet50_fpn": PROXY_ITERATED,
4557*da0073e9SAndroid Build Coastguard Worker        "retinanet_resnet50_fpn": PROXY_ITERATED,
4558*da0073e9SAndroid Build Coastguard Worker        "retinanet_resnet50_fpn_v2": PROXY_ITERATED,
4559*da0073e9SAndroid Build Coastguard Worker        "ssd300_vgg16": PROXY_ITERATED,
4560*da0073e9SAndroid Build Coastguard Worker        "fcos_resnet50_fpn": PROXY_ITERATED,
4561*da0073e9SAndroid Build Coastguard Worker        "ssdlite320_mobilenet_v3_large": PROXY_ITERATED,
4562*da0073e9SAndroid Build Coastguard Worker    }
4563*da0073e9SAndroid Build Coastguard Worker    UNSCRIPTABLE_MODELS = {
4564*da0073e9SAndroid Build Coastguard Worker        "googlenet": INCONSISTENT_TYPE,
4565*da0073e9SAndroid Build Coastguard Worker        "inception_v3": INCONSISTENT_TYPE,
4566*da0073e9SAndroid Build Coastguard Worker    }
4567*da0073e9SAndroid Build Coastguard Worker
4568*da0073e9SAndroid Build Coastguard Worker    output_transform = {
4569*da0073e9SAndroid Build Coastguard Worker        "fcn_resnet50": lambda x: x["out"],
4570*da0073e9SAndroid Build Coastguard Worker        "fcn_resnet101": lambda x: x["out"],
4571*da0073e9SAndroid Build Coastguard Worker        "deeplabv3_resnet50": lambda x: x["out"],
4572*da0073e9SAndroid Build Coastguard Worker        "deeplabv3_resnet101": lambda x: x["out"],
4573*da0073e9SAndroid Build Coastguard Worker        "deeplabv3_mobilenet_v3_large": lambda x: x["out"],
4574*da0073e9SAndroid Build Coastguard Worker        "lraspp_mobilenet_v3_large": lambda x: x["out"],
4575*da0073e9SAndroid Build Coastguard Worker        "fasterrcnn_resnet50_fpn": lambda x: x[1],
4576*da0073e9SAndroid Build Coastguard Worker        "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
4577*da0073e9SAndroid Build Coastguard Worker        "fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1],
4578*da0073e9SAndroid Build Coastguard Worker        "maskrcnn_resnet50_fpn": lambda x: x[1],
4579*da0073e9SAndroid Build Coastguard Worker        "keypointrcnn_resnet50_fpn": lambda x: x[1],
4580*da0073e9SAndroid Build Coastguard Worker        "retinanet_resnet50_fpn": lambda x: x[1],
4581*da0073e9SAndroid Build Coastguard Worker    }
4582*da0073e9SAndroid Build Coastguard Worker
4583*da0073e9SAndroid Build Coastguard Worker    @classmethod
4584*da0073e9SAndroid Build Coastguard Worker    def generate_test_fn(cls, name, x, kwargs):
4585*da0073e9SAndroid Build Coastguard Worker        def run_test(self):
4586*da0073e9SAndroid Build Coastguard Worker            model = torchvision_models.get_model(name, **kwargs)
4587*da0073e9SAndroid Build Coastguard Worker            model = model.eval()
4588*da0073e9SAndroid Build Coastguard Worker            if name in self.UNTRACEABLE_MODELS:
4589*da0073e9SAndroid Build Coastguard Worker                err, exc = self.UNTRACEABLE_MODELS[name]
4590*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(err, exc):
4591*da0073e9SAndroid Build Coastguard Worker                    graph = symbolic_trace(model)
4592*da0073e9SAndroid Build Coastguard Worker            else:
4593*da0073e9SAndroid Build Coastguard Worker                out_transform = self.output_transform.get(name, lambda x: x)
4594*da0073e9SAndroid Build Coastguard Worker                graph : torch.fx.GraphModule = symbolic_trace(model)
4595*da0073e9SAndroid Build Coastguard Worker                a = out_transform(model(x))
4596*da0073e9SAndroid Build Coastguard Worker                b = out_transform(graph(x))
4597*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a, b)
4598*da0073e9SAndroid Build Coastguard Worker
4599*da0073e9SAndroid Build Coastguard Worker                if name in self.UNSCRIPTABLE_MODELS:
4600*da0073e9SAndroid Build Coastguard Worker                    err, exc = self.UNSCRIPTABLE_MODELS[name]
4601*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(err, exc):
4602*da0073e9SAndroid Build Coastguard Worker                        script = torch.jit.script(graph)
4603*da0073e9SAndroid Build Coastguard Worker                else:
4604*da0073e9SAndroid Build Coastguard Worker                    script = torch.jit.script(graph)
4605*da0073e9SAndroid Build Coastguard Worker                    c = out_transform(script(x))
4606*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(a, c)
4607*da0073e9SAndroid Build Coastguard Worker
4608*da0073e9SAndroid Build Coastguard Worker        return run_test
4609*da0073e9SAndroid Build Coastguard Worker
4610*da0073e9SAndroid Build Coastguard Worker    @classmethod
4611*da0073e9SAndroid Build Coastguard Worker    def generate_classification_tests(cls):
4612*da0073e9SAndroid Build Coastguard Worker        for k in torchvision_models.list_models(module=torchvision_models):
4613*da0073e9SAndroid Build Coastguard Worker            test_name = 'test_torchvision_models_' + k
4614*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(1, 3, 299, 299) if k in ['inception_v3'] else torch.rand(1, 3, 224, 224)
4615*da0073e9SAndroid Build Coastguard Worker            kwargs = dict(num_classes=50)
4616*da0073e9SAndroid Build Coastguard Worker            model_test = cls.generate_test_fn(k, x, kwargs)
4617*da0073e9SAndroid Build Coastguard Worker            setattr(cls, test_name, model_test)
4618*da0073e9SAndroid Build Coastguard Worker
4619*da0073e9SAndroid Build Coastguard Worker    @classmethod
4620*da0073e9SAndroid Build Coastguard Worker    def generate_segmentation_tests(cls):
4621*da0073e9SAndroid Build Coastguard Worker        for k in torchvision_models.list_models(module=torchvision_models.segmentation):
4622*da0073e9SAndroid Build Coastguard Worker            test_name = 'test_torchvision_models_segmentation_' + k
4623*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(1, 3, 32, 32)
4624*da0073e9SAndroid Build Coastguard Worker            kwargs = dict(num_classes=10, pretrained_backbone=False)
4625*da0073e9SAndroid Build Coastguard Worker            model_test = cls.generate_test_fn(k, x, kwargs)
4626*da0073e9SAndroid Build Coastguard Worker            setattr(cls, test_name, model_test)
4627*da0073e9SAndroid Build Coastguard Worker
4628*da0073e9SAndroid Build Coastguard Worker    @classmethod
4629*da0073e9SAndroid Build Coastguard Worker    def generate_detection_tests(cls):
4630*da0073e9SAndroid Build Coastguard Worker        for k in torchvision_models.list_models(module=torchvision_models.detection):
4631*da0073e9SAndroid Build Coastguard Worker            test_name = 'test_torchvision_models_detection_' + k
4632*da0073e9SAndroid Build Coastguard Worker            x = [torch.rand(3, 300, 300)]
4633*da0073e9SAndroid Build Coastguard Worker            kwargs = dict(num_classes=10, pretrained_backbone=False)
4634*da0073e9SAndroid Build Coastguard Worker            model_test = cls.generate_test_fn(k, x, kwargs)
4635*da0073e9SAndroid Build Coastguard Worker            setattr(cls, test_name, model_test)
4636*da0073e9SAndroid Build Coastguard Worker
4637*da0073e9SAndroid Build Coastguard Worker    @classmethod
4638*da0073e9SAndroid Build Coastguard Worker    def generate_video_tests(cls):
4639*da0073e9SAndroid Build Coastguard Worker        for k in torchvision_models.list_models(module=torchvision_models.video):
4640*da0073e9SAndroid Build Coastguard Worker            test_name = 'test_torchvision_models_video_' + k
4641*da0073e9SAndroid Build Coastguard Worker            x = (
4642*da0073e9SAndroid Build Coastguard Worker                torch.rand(1, 3, 4, 112, 112)
4643*da0073e9SAndroid Build Coastguard Worker                if k not in {"mvit_v1_b", "mvit_v2_s", "s3d"}
4644*da0073e9SAndroid Build Coastguard Worker                else torch.rand(1, 3, 16, 224, 224)
4645*da0073e9SAndroid Build Coastguard Worker            )
4646*da0073e9SAndroid Build Coastguard Worker            kwargs = dict(num_classes=50)
4647*da0073e9SAndroid Build Coastguard Worker            model_test = cls.generate_test_fn(k, x, kwargs)
4648*da0073e9SAndroid Build Coastguard Worker            setattr(cls, test_name, model_test)
4649*da0073e9SAndroid Build Coastguard Worker
4650*da0073e9SAndroid Build Coastguard Worker    @classmethod
4651*da0073e9SAndroid Build Coastguard Worker    def generate_tests(cls):
4652*da0073e9SAndroid Build Coastguard Worker        cls.generate_classification_tests()
4653*da0073e9SAndroid Build Coastguard Worker        cls.generate_detection_tests()
4654*da0073e9SAndroid Build Coastguard Worker        cls.generate_segmentation_tests()
4655*da0073e9SAndroid Build Coastguard Worker        cls.generate_video_tests()
4656*da0073e9SAndroid Build Coastguard Worker
4657*da0073e9SAndroid Build Coastguard Workerif HAS_TORCHVISION:
4658*da0073e9SAndroid Build Coastguard Worker    TestVisionTracing.generate_tests()
4659*da0073e9SAndroid Build Coastguard Worker
4660*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__':
4661*da0073e9SAndroid Build Coastguard Worker    run_tests()
4662