# xref: /aosp_15_r20/external/pytorch/test/dynamo/test_modules.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: dynamo"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport collections
4*da0073e9SAndroid Build Coastguard Workerimport contextlib
5*da0073e9SAndroid Build Coastguard Workerimport copy
6*da0073e9SAndroid Build Coastguard Workerimport itertools
7*da0073e9SAndroid Build Coastguard Workerimport os
8*da0073e9SAndroid Build Coastguard Workerimport tempfile
9*da0073e9SAndroid Build Coastguard Workerimport traceback
10*da0073e9SAndroid Build Coastguard Workerimport types
11*da0073e9SAndroid Build Coastguard Workerimport unittest
12*da0073e9SAndroid Build Coastguard Workerfrom copy import deepcopy
13*da0073e9SAndroid Build Coastguard Workerfrom functools import partial
14*da0073e9SAndroid Build Coastguard Workerfrom typing import Dict, NamedTuple, Tuple
15*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import patch
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Workerimport torch
18*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case
19*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing
20*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F
21*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.debug_utils import same_two_models
22*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.eval_frame import unsupported
23*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.mutation_guard import GenerationTracker
24*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import expectedFailureDynamic, same
25*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.variables.torch_function import TensorWithTFOverrideVariable
26*da0073e9SAndroid Build Coastguard Workerfrom torch.nn.modules.lazy import LazyModuleMixin
27*da0073e9SAndroid Build Coastguard Workerfrom torch.nn.parameter import Parameter, UninitializedParameter
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Workertry:
31*da0073e9SAndroid Build Coastguard Worker    from . import test_functions
32*da0073e9SAndroid Build Coastguard Workerexcept ImportError:
33*da0073e9SAndroid Build Coastguard Worker    import test_functions
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker
# Module-level counters; both are mutated in place by update_global().
_variable = 0
_variable1 = 0


def update_global():
    """Increment the module-level counters ``_variable`` and ``_variable1`` by one."""
    global _variable, _variable1
    _variable += 1
    _variable1 += 1
45*da0073e9SAndroid Build Coastguard Worker
class BasicModule(torch.nn.Module):
    """A 10->10 linear layer followed by ReLU, multiplied by a random ``scale``.

    ``scale`` is assigned as a plain tensor (from ``torch.randn``), so it is
    neither a parameter nor a registered buffer of the module.
    """

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        """Return ``relu(linear1(x)) * scale``."""
        return F.relu(self.linear1(x)) * self.scale
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker
class FnMember(torch.nn.Module):
    """Linear layer whose activation is stored as a function-valued attribute.

    ``activation`` is set to ``F.relu``; ``forward`` applies it only when the
    attribute is truthy.
    """

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.activation = F.relu

    def forward(self, x):
        """Return ``linear1(x)``, passed through ``activation`` if it is truthy."""
        x = self.linear1(x)
        # Truthiness test (not an ``is None`` comparison) on the function member.
        if self.activation:
            x = self.activation(x)
        return x
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker
class FnMemberCmp(torch.nn.Module):
    """Linear layer with a caller-supplied activation compared against ``None``.

    Args:
        activation: callable applied to the linear output, or ``None``.
    """

    def __init__(self, activation):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.activation = activation

    def forward(self, x):
        """Apply ``linear1`` then ``activation``; use ``torch.sigmoid`` if it is ``None``."""
        x = self.linear1(x)
        # The two checks are complementary, so exactly one branch runs.
        if self.activation is not None:
            x = self.activation(x)
        if self.activation is None:
            x = torch.sigmoid(x)
        return x
82*da0073e9SAndroid Build Coastguard Worker
83*da0073e9SAndroid Build Coastguard Worker
class SubmoduleExample(torch.nn.Module):
    """Two ``BasicModule`` submodules applied in sequence, then scaled.

    ``scale`` is a plain random tensor attribute of shape ``(1, 10)``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        """Return ``layer2(layer1(x)) * scale``."""
        x = self.layer1(x)
        x = self.layer2(x)
        return x * self.scale
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker
class IsTrainingCheck(torch.nn.Module):
    """Chooses between two linear layers based on ``self.training``.

    The constructor leaves the module in training mode via ``self.train(True)``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.linear2 = torch.nn.Linear(10, 10)
        self.train(True)

    def forward(self, x):
        """Return ``relu(linear1(x))`` in training mode, else ``relu(linear2(x))``."""
        if self.training:
            mod = self.linear1
        else:
            mod = self.linear2
        return F.relu(mod(x))
110*da0073e9SAndroid Build Coastguard Worker
111*da0073e9SAndroid Build Coastguard Worker
class IsEvalCheck(IsTrainingCheck):
    """``IsTrainingCheck`` variant that is switched to eval mode on construction."""

    def __init__(self) -> None:
        super().__init__()
        # The parent constructor calls train(True); override it afterwards.
        self.train(False)
116*da0073e9SAndroid Build Coastguard Worker
117*da0073e9SAndroid Build Coastguard Worker
class ModuleMethodCall(torch.nn.Module):
    """Calls two submodules through a shared instance-method helper.

    ``scale`` is a plain random tensor attribute of shape ``(1, 10)``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.scale = torch.randn(1, 10)

    def call_and_scale(self, mod, x):
        """Return ``mod(x) * self.scale``."""
        x = mod(x)
        return x * self.scale

    def forward(self, x):
        """Return the sum of ``call_and_scale`` applied with ``layer1`` and ``layer2``."""
        x1 = self.call_and_scale(self.layer1, x)
        x2 = self.call_and_scale(self.layer2, x)
        return x1 + x2
133*da0073e9SAndroid Build Coastguard Worker
134*da0073e9SAndroid Build Coastguard Worker
class UnsupportedMethodCall(torch.nn.Module):
    """Routes a submodule call through a helper method that ends in ``unsupported``.

    ``unsupported`` is imported from ``torch._dynamo.eval_frame``; what it
    returns is not visible from this file.
    """

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.scale = torch.randn(1, 10)

    def call_and_scale(self, mod, x):
        """Return ``unsupported(y, y)`` where ``y = mod(x) * self.scale``."""
        x = mod(x)
        x = x * self.scale
        return unsupported(x, x)

    def forward(self, x):
        """Return ``x + call_and_scale(layer1, x)``."""
        x1 = self.call_and_scale(self.layer1, x)
        return x + x1
149*da0073e9SAndroid Build Coastguard Worker
150*da0073e9SAndroid Build Coastguard Worker
class UnsupportedModule(torch.nn.Module):
    """Module whose ``forward`` finishes by calling ``unsupported``.

    ``unsupported`` is imported from ``torch._dynamo.eval_frame``; what it
    returns is not visible from this file.
    """

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        """Return ``unsupported(y, y)`` where ``y = layer1(x) * scale``."""
        x = self.layer1(x) * self.scale
        return unsupported(x, x)
160*da0073e9SAndroid Build Coastguard Worker
161*da0073e9SAndroid Build Coastguard Worker
class UnsupportedModuleCall(torch.nn.Module):
    """Wraps an ``UnsupportedModule`` with arithmetic before and after the call."""

    def __init__(self) -> None:
        super().__init__()
        self.mod = UnsupportedModule()

    def forward(self, x):
        """Return ``1 + mod(x * 1.5)``."""
        return 1 + self.mod(x * 1.5)
169*da0073e9SAndroid Build Coastguard Worker
170*da0073e9SAndroid Build Coastguard Worker
class ModuleWithStaticForward(torch.nn.Module):
    """Module whose ``forward`` is declared as a ``@staticmethod`` (no ``self``)."""

    @staticmethod
    def forward(x):
        """Return ``x * sigmoid(x)``."""
        return x * torch.sigmoid(x)
175*da0073e9SAndroid Build Coastguard Worker
176*da0073e9SAndroid Build Coastguard Worker
class ModuleCallModuleWithStaticForward(torch.nn.Module):
    """Delegates ``forward`` to a ``ModuleWithStaticForward`` submodule."""

    def __init__(self) -> None:
        super().__init__()
        self.mod = ModuleWithStaticForward()

    def forward(self, x):
        """Return ``mod(x)``."""
        return self.mod(x)
184*da0073e9SAndroid Build Coastguard Worker
185*da0073e9SAndroid Build Coastguard Worker
class ModuleStaticMethodCall(torch.nn.Module):
    """Calls two submodules through a ``@staticmethod`` helper.

    Because the helper has no ``self``, ``scale`` is passed in explicitly.
    """

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.scale = torch.randn(1, 10)

    @staticmethod
    def call_and_scale(scale, mod, x):
        """Return ``mod(x) * scale``."""
        x = mod(x)
        return x * scale

    def forward(self, x):
        """Return the sum of ``call_and_scale`` applied with ``layer1`` and ``layer2``."""
        x1 = self.call_and_scale(self.scale, self.layer1, x)
        x2 = self.call_and_scale(self.scale, self.layer2, x)
        return x1 + x2
202*da0073e9SAndroid Build Coastguard Worker
203*da0073e9SAndroid Build Coastguard Worker
class ModuleClassMethodCall(torch.nn.Module):
    """Calls two submodules through a ``@classmethod`` helper.

    The helper receives the class rather than the instance, so ``scale`` is
    passed in explicitly.
    """

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.scale = torch.randn(1, 10)

    @classmethod
    def call_and_scale(cls, scale, mod, x):
        """Return ``mod(x) * scale``; ``cls`` is unused."""
        x = mod(x)
        return x * scale

    def forward(self, x):
        """Return the sum of ``call_and_scale`` applied with ``layer1`` and ``layer2``."""
        x1 = self.call_and_scale(self.scale, self.layer1, x)
        x2 = self.call_and_scale(self.scale, self.layer2, x)
        return x1 + x2
220*da0073e9SAndroid Build Coastguard Worker
221*da0073e9SAndroid Build Coastguard Worker
class ModuleProperty(torch.nn.Module):
    """Reads a tensor attribute through a ``@property`` inside ``forward``."""

    def __init__(self) -> None:
        super().__init__()
        self.scale = torch.randn(1, 10)

    @property
    def scale_alias(self):
        """Return the same tensor object as ``self.scale``."""
        return self.scale

    def forward(self, x):
        """Return ``x * scale_alias``."""
        return x * self.scale_alias
233*da0073e9SAndroid Build Coastguard Worker
234*da0073e9SAndroid Build Coastguard Worker
class NestedModuleList(torch.nn.Module):
    """A ``ModuleList`` of three inner ``ModuleList`` pairs ``[Linear(10, 10), ReLU]``."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.ModuleList([])
        for _ in range(3):
            self.layers.append(
                torch.nn.ModuleList(
                    [
                        torch.nn.Linear(10, 10),
                        torch.nn.ReLU(),
                    ]
                )
            )

    def forward(self, x):
        """Apply each ``(layer, act)`` pair in order as ``act(layer(x))``."""
        # Each inner ModuleList is unpacked directly in the loop target.
        for layer, act in self.layers:
            x = act(layer(x))
        return x
253*da0073e9SAndroid Build Coastguard Worker
254*da0073e9SAndroid Build Coastguard Worker
class ConstLoop(torch.nn.Module):
    """Applies the same linear+sigmoid step ``self.count`` times (``count`` is 3)."""

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.count = 3

    def forward(self, x):
        """Return ``x`` after ``count`` rounds of ``sigmoid(linear1(x))``."""
        # The loop index itself is unused; only the trip count matters.
        for _ in range(self.count):
            x = torch.sigmoid(self.linear1(x))
        return x
265*da0073e9SAndroid Build Coastguard Worker
266*da0073e9SAndroid Build Coastguard Worker
class ViaModuleCall(torch.nn.Module):
    """Passes its result through a function from the sibling ``test_functions`` module."""

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)

    def forward(self, x):
        """Return ``test_functions.constant3(sigmoid(linear1(x)), x)``.

        NOTE(review): ``constant3`` is defined in ``test_functions``; what it
        computes is not visible from this file.
        """
        return test_functions.constant3(torch.sigmoid(self.linear1(x)), x)
274*da0073e9SAndroid Build Coastguard Worker
275*da0073e9SAndroid Build Coastguard Worker
class IsNoneLayer(torch.nn.Module):
    """Has one real layer and one ``None`` layer, each guarded by ``is not None``.

    The constructor leaves the module in training mode via ``self.train(True)``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = torch.nn.Linear(10, 10)
        self.layer2 = None
        self.train(True)

    def forward(self, x):
        """Apply ``layer1`` and ``layer2`` in turn, skipping whichever is ``None``."""
        if self.layer1 is not None:
            x = self.layer1(x)
        if self.layer2 is not None:
            x = self.layer2(x)
        return x
289*da0073e9SAndroid Build Coastguard Worker
290*da0073e9SAndroid Build Coastguard Worker
class LayerList(torch.nn.Module):
    """Keeps its layers in a plain Python ``list`` rather than an ``nn.ModuleList``.

    Because the container is an ordinary list, the layers are not registered
    as submodules and their parameters do not appear in ``parameters()``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.layers = [
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
        ]

    def forward(self, x):
        """Apply each entry of ``self.layers`` to ``x`` in order."""
        for layer in self.layers:
            x = layer(x)
        return x
304*da0073e9SAndroid Build Coastguard Worker
305*da0073e9SAndroid Build Coastguard Worker
class ModuleList(torch.nn.Module):
    """Walks one ``nn.ModuleList`` six times, each time with a different iteration form.

    The forms are: index loop, direct iteration, ``zip`` with tensors, ``zip``
    with ints, ``enumerate``, and ``enumerate`` over a reversed slice. The
    loops are kept as separate, explicit statements on purpose.
    """

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
            ]
        )

    def forward(self, x):
        """Run ``x`` through all six passes over ``self.layers`` and return the result."""
        # Pass 1: integer indexing.
        for i in range(len(self.layers)):
            x = self.layers[i](x)

        # Pass 2: direct iteration.
        for layer in self.layers:
            x = layer(x)

        # Pass 3: the tuple is built once, so every ``val`` is the value of
        # ``x`` from before this loop started.
        for layer, val in zip(self.layers, (x, x, x, x)):
            x = layer(x) + val

        # Pass 4: zip against Python int constants.
        for layer, val in zip(self.layers, (1, 2, 3, 4)):
            x = layer(x) + val

        # Pass 5: multiply by the position; index 0 zeroes the first result.
        for idx, layer in enumerate(self.layers):
            x = layer(x) * idx

        # Pass 6: same as pass 5 but over the layers in reverse order.
        for idx, layer in enumerate(self.layers[::-1]):
            x = layer(x) * idx

        return x
338*da0073e9SAndroid Build Coastguard Worker
339*da0073e9SAndroid Build Coastguard Worker
class CustomGetItemModuleList(torch.nn.Module):
    """Exposes an inner ``nn.ModuleList`` through its own ``__getitem__`` and ``__len__``."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
            ]
        )

    def __getitem__(self, idx: int):
        """Return ``self.layers[idx]``."""
        return self.layers[idx]

    def __len__(self) -> int:
        """Return the number of entries in ``self.layers``."""
        return len(self.layers)

    def forward(self, x):
        """Apply every layer in order, going through ``len(self)`` and ``self[i]``."""
        for i in range(len(self)):
            x = self[i](x)

        return x
363*da0073e9SAndroid Build Coastguard Worker
364*da0073e9SAndroid Build Coastguard Worker
class ModuleDict(torch.nn.Module):
    """Holds a single ``Linear(10, 10)`` under key ``"0"`` of an ``nn.ModuleDict``."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.ModuleDict(
            {
                "0": torch.nn.Linear(10, 10),
            }
        )

    def forward(self, x):
        """Return ``layers["0"](x)``."""
        # TODO(future PR): handle more logic
        x = self.layers["0"](x)
        return x
378*da0073e9SAndroid Build Coastguard Worker
379*da0073e9SAndroid Build Coastguard Worker
class ParameterDict(torch.nn.Module):
    """Holds a single 10x10 ``Parameter`` under key ``"0"`` of an ``nn.ParameterDict``."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.ParameterDict(
            {
                "0": torch.nn.Parameter(torch.randn(10, 10)),
            }
        )

    def forward(self, x):
        """Return ``layers["0"].mm(x)``, with the parameter as the left operand."""
        x = self.layers["0"].mm(x)
        return x
392*da0073e9SAndroid Build Coastguard Worker
393*da0073e9SAndroid Build Coastguard Worker
class CustomGetItemParameterDict(torch.nn.Module):
    """Exposes an inner ``nn.ParameterDict`` through its own ``__getitem__``."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.ParameterDict(
            {
                "0": torch.nn.Parameter(torch.randn(10, 10)),
            }
        )

    # The stored values are Parameters, not Modules, so the return annotation
    # names ``torch.nn.Parameter``.
    def __getitem__(self, key: str) -> torch.nn.Parameter:
        """Return ``self.layers[key]``."""
        return self.layers[key]

    def forward(self, x):
        """Return ``self["0"].mm(x)``, with the parameter as the left operand."""
        x = self["0"].mm(x)
        return x
409*da0073e9SAndroid Build Coastguard Worker
410*da0073e9SAndroid Build Coastguard Worker
class CustomGetItemModuleDict(torch.nn.Module):
    """Exposes an inner ``nn.ModuleDict`` through its own ``__getitem__``."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.ModuleDict(
            {
                "0": torch.nn.Linear(10, 10),
            }
        )

    def __getitem__(self, key: str) -> torch.nn.Module:
        """Return ``self.layers[key]``."""
        return self.layers[key]

    def forward(self, x):
        """Return ``self["0"](x)``."""
        x = self["0"](x)
        return x
426*da0073e9SAndroid Build Coastguard Worker
427*da0073e9SAndroid Build Coastguard Worker
class TensorList(torch.nn.Module):
    """Holds a tuple of four random tensors shaped ``(1, 10)`` and ``(10, 1)`` alternately."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = (
            torch.randn((1, 10)),
            torch.randn((10, 1)),
            torch.randn((1, 10)),
            torch.randn((10, 1)),
        )

    def forward(self, x):
        """Multiply ``x`` elementwise by each tensor in ``self.layers`` in order."""
        for layer in self.layers:
            x = x * layer
        return x
442*da0073e9SAndroid Build Coastguard Worker
443*da0073e9SAndroid Build Coastguard Worker
class Children(torch.nn.Module):
    """Applies its four submodules (Linear, ReLU, Linear, ReLU) via ``self.children()``."""

    def __init__(self) -> None:
        super().__init__()
        self.l1 = torch.nn.Linear(10, 10)
        self.l2 = torch.nn.ReLU()
        self.l3 = torch.nn.Linear(10, 10)
        self.l4 = torch.nn.ReLU()

    def forward(self, x):
        """Apply each module yielded by ``self.children()`` to ``x`` in turn."""
        for block in self.children():
            x = block(x)
        return x
456*da0073e9SAndroid Build Coastguard Worker
457*da0073e9SAndroid Build Coastguard Worker
class NamedChildren(torch.nn.Module):
    """Applies its four submodules (Linear, ReLU, Linear, ReLU) via ``self.named_children()``."""

    def __init__(self) -> None:
        super().__init__()
        self.l1 = torch.nn.Linear(10, 10)
        self.l2 = torch.nn.ReLU()
        self.l3 = torch.nn.Linear(10, 10)
        self.l4 = torch.nn.ReLU()

    def forward(self, x):
        """Apply each module yielded by ``self.named_children()``; names are ignored."""
        for _, block in self.named_children():
            x = block(x)
        return x
470*da0073e9SAndroid Build Coastguard Worker
471*da0073e9SAndroid Build Coastguard Worker
class IntArg(torch.nn.Module):
    """Fixture whose forward accepts an extra ``offset`` argument (default 1)."""

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = torch.nn.Linear(10, 10)

    def forward(self, x, offset=1):
        """Return ``relu(layer1(x)) + offset``."""
        activated = F.relu(self.layer1(x))
        return activated + offset
480*da0073e9SAndroid Build Coastguard Worker
481*da0073e9SAndroid Build Coastguard Worker
class Seq(torch.nn.Module):
    """Fixture wrapping a four-stage ``torch.nn.Sequential``."""

    def __init__(self) -> None:
        super().__init__()
        # Built in the same order they run: Linear, ReLU, Linear, ReLU.
        stages = [
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
        ]
        self.layers = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the sequential container."""
        out = self.layers(x)
        return out
494*da0073e9SAndroid Build Coastguard Worker
495*da0073e9SAndroid Build Coastguard Worker
class Cfg:
    """Plain (non-module) settings object holding ``val`` and ``count``."""

    def __init__(self) -> None:
        self.val, self.count = 0.5, 3
500*da0073e9SAndroid Build Coastguard Worker
501*da0073e9SAndroid Build Coastguard Worker
class CfgModule(torch.nn.Module):
    """Fixture that reads its loop count and bias from a plain ``Cfg`` object."""

    def __init__(self) -> None:
        super().__init__()
        self.cfg = Cfg()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        """Apply ``layer(x + cfg.val)`` repeatedly, ``cfg.count`` times."""
        out = x
        for _ in range(self.cfg.count):
            out = self.layer(out + self.cfg.val)
        return out
512*da0073e9SAndroid Build Coastguard Worker
513*da0073e9SAndroid Build Coastguard Worker
class StringMember(torch.nn.Module):
    """Fixture that branches on a string attribute stored on the module."""

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.mode = "some_string"

    def forward(self, x):
        """Return ``relu(linear1(x))`` when ``mode`` matches, otherwise ``None``."""
        if self.mode == "some_string":
            projected = self.linear1(x)
            return F.relu(projected)
        # Any other mode produces no result.
        return None
523*da0073e9SAndroid Build Coastguard Worker
524*da0073e9SAndroid Build Coastguard Worker
class _Block(torch.nn.Module):
    """Concatenate a sequence of tensors along dim 1 and scale the result by 1.5."""

    def forward(self, x):
        joined = torch.cat(x, 1)
        return 1.5 * joined
528*da0073e9SAndroid Build Coastguard Worker
529*da0073e9SAndroid Build Coastguard Worker
class _DenseBlock(torch.nn.ModuleDict):
    """ModuleDict of ``_Block`` layers where each layer sees all earlier outputs."""

    _version = 2

    def __init__(
        self,
        num_layers: int = 3,
    ) -> None:
        super().__init__()
        # Layers are named denselayer1 .. denselayer<num_layers>.
        for layer_no in range(1, num_layers + 1):
            self.add_module("denselayer%d" % layer_no, _Block())

    def forward(self, init_features):
        """Grow a feature list layer by layer, then concatenate it along dim 1."""
        collected = [init_features]
        for dense_layer in self.values():
            # Each layer receives the full list gathered so far.
            produced = dense_layer(collected)
            collected.append(produced)
        return torch.cat(collected, 1)
547*da0073e9SAndroid Build Coastguard Worker
548*da0073e9SAndroid Build Coastguard Worker
class DenseNetBlocks(torch.nn.Module):
    """Thin wrapper that owns a default ``_DenseBlock`` and forwards to it."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = _DenseBlock()

    def forward(self, x):
        """Delegate to the wrapped dense block."""
        out = self.layers(x)
        return out
556*da0073e9SAndroid Build Coastguard Worker
557*da0073e9SAndroid Build Coastguard Worker
class MaterializedModule(torch.nn.Module):
    """Identity module with a single ``param`` slot.

    This is the class the lazy module defined below is transformed into once
    it has been initialized with its first input.
    """

    param: Parameter

    def __init__(self) -> None:
        super().__init__()
        # Reserve the parameter name without giving it a value yet.
        self.register_parameter(name="param", param=None)

    def forward(self, x):
        """Return the input unchanged."""
        return x
570*da0073e9SAndroid Build Coastguard Worker
571*da0073e9SAndroid Build Coastguard Worker
class LazyModule(LazyModuleMixin, MaterializedModule):
    """Lazy counterpart of ``MaterializedModule``.

    ``param`` starts out as an ``UninitializedParameter`` and is given a
    concrete shape by ``initialize_parameters``. ``cls_to_become`` names
    ``MaterializedModule`` as the class this module turns into afterwards.
    """

    param: UninitializedParameter
    cls_to_become = MaterializedModule

    def __init__(self) -> None:
        super().__init__()
        self.param = UninitializedParameter()

    def initialize_parameters(self, x):
        """Materialize ``param`` with the same shape as ``x``."""
        # force graph break to ensure this was not inlined
        torch._dynamo.graph_break()
        self.param.materialize(x.shape)
584*da0073e9SAndroid Build Coastguard Worker
585*da0073e9SAndroid Build Coastguard Worker
class LazyMLP(torch.nn.Module):
    """Two-layer perceptron built from ``LazyLinear`` layers (10 then 1 outputs)."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.LazyLinear(10)
        self.relu1 = torch.nn.ReLU()
        self.fc2 = torch.nn.LazyLinear(1)
        self.relu2 = torch.nn.ReLU()

    def forward(self, input):
        """Return ``relu2(fc2(relu1(fc1(input))))``."""
        hidden = self.fc1(input)
        hidden = self.relu1(hidden)
        out = self.fc2(hidden)
        return self.relu2(out)
598*da0073e9SAndroid Build Coastguard Worker
599*da0073e9SAndroid Build Coastguard Worker
class MyInput(NamedTuple):
    """Two-field input record: a nested mapping ``x`` and a tensor ``y``."""

    # NOTE(review): LazyLayerWithNamedTupleInput indexes x["a"] by integer
    # position, so the value stored under "a" is presumably a sequence of
    # tensors rather than the Dict annotated here -- confirm against callers.
    x: Dict[str, Dict[str, torch.Tensor]]
    y: torch.Tensor
603*da0073e9SAndroid Build Coastguard Worker
604*da0073e9SAndroid Build Coastguard Worker
class LazyLayerWithNamedTupleInput(LazyModuleMixin, torch.nn.Module):
    """Lazy layer whose input is a record exposing ``x["a"]`` as a sequence."""

    def __init__(self) -> None:
        super().__init__()

    def initialize_parameters(self, input):
        """Create ``_param`` (filled with 0.5) shaped like the first ``x["a"]`` entry."""
        with torch.no_grad():
            first_shape = input.x["a"][0].shape
            self._param = torch.nn.Parameter(torch.empty(first_shape).fill_(0.5))

    def forward(self, input):
        """Return the running sum (starting from 0) of the entries of ``x["a"]``."""
        entries = input.x["a"]
        total = 0
        for entry in entries:
            total = total + entry
        return total
621*da0073e9SAndroid Build Coastguard Worker
622*da0073e9SAndroid Build Coastguard Worker
class LazyModuleWithNamedTupleInput(torch.nn.Module):
    """Regular module that owns a ``LazyLayerWithNamedTupleInput`` and forwards to it."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = LazyLayerWithNamedTupleInput()

    def forward(self, input):
        """Pass ``input`` straight to the lazy layer."""
        out = self.layer(input)
        return out
630*da0073e9SAndroid Build Coastguard Worker
631*da0073e9SAndroid Build Coastguard Worker
class LazyLayerWithListInput(LazyModuleMixin, torch.nn.Module):
    """Lazy layer whose single input is a sequence of tensors."""

    def __init__(self) -> None:
        super().__init__()

    def initialize_parameters(self, input):
        """Create ``_param`` (filled with 0.5) shaped like the first entry."""
        with torch.no_grad():
            first_shape = input[0].shape
            self._param = torch.nn.Parameter(torch.empty(first_shape).fill_(0.5))

    def forward(self, input):
        """Return the running sum (starting from 0) of the entries of ``input``."""
        total = 0
        for entry in input:
            total = total + entry
        return total
645*da0073e9SAndroid Build Coastguard Worker
646*da0073e9SAndroid Build Coastguard Worker
class LazyModuleWithListInput(torch.nn.Module):
    """Regular module that feeds all but the last input entry to a lazy layer."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = LazyLayerWithListInput()

    def forward(self, input):
        """Drop the final entry of ``input`` and hand the rest to the layer."""
        all_but_last = input[:-1]
        return self.layer(all_but_last)
654*da0073e9SAndroid Build Coastguard Worker
655*da0073e9SAndroid Build Coastguard Worker
class LazyModuleWithLazySubmodule(LazyModuleMixin, torch.nn.Module):
    """Lazy module whose initialization step creates another lazy layer."""

    def __init__(self) -> None:
        super().__init__()

    def initialize_parameters(self, input):
        """Attach a fresh ``LazyLayerWithListInput`` as ``self.layer``."""
        with torch.no_grad():
            self.layer = LazyLayerWithListInput()

    def forward(self, x):
        """Delegate to the lazily created submodule."""
        out = self.layer(x)
        return out
666*da0073e9SAndroid Build Coastguard Worker
667*da0073e9SAndroid Build Coastguard Worker
class LazyLayerWithInputs(LazyModuleMixin, torch.nn.Module):
    """Lazy layer taking two sequences of tensors, ``x`` and ``y``."""

    def __init__(self) -> None:
        super().__init__()

    def initialize_parameters(self, x, y):
        """Create ``_param_x`` / ``_param_y`` (filled with 0.5) from the first entries."""
        with torch.no_grad():
            x_shape = x[0].shape
            self._param_x = torch.nn.Parameter(torch.empty(x_shape).fill_(0.5))
            y_shape = y[0].shape
            self._param_y = torch.nn.Parameter(torch.empty(y_shape).fill_(0.5))

    def forward(self, x, y):
        """Sum the entries of ``x`` and of ``y`` separately, then add the two sums."""
        x_total = 0
        for x_entry in x:
            x_total = x_total + x_entry
        y_total = 0
        for y_entry in y:
            y_total = y_total + y_entry
        return x_total + y_total
685*da0073e9SAndroid Build Coastguard Worker
686*da0073e9SAndroid Build Coastguard Worker
class LazyModuleKwArgs(LazyModuleMixin, torch.nn.Module):
    """Lazy module that calls its lazily created submodule with a keyword argument."""

    def __init__(self) -> None:
        super().__init__()

    def initialize_parameters(self, *args, **kwargs):
        """Attach a fresh ``LazyLayerWithInputs`` as ``self.layer``; arguments are unused."""
        with torch.no_grad():
            self.layer = LazyLayerWithInputs()

    def forward(self, x, y):
        """Call the submodule with ``x`` positionally and ``y`` by keyword."""
        out = self.layer(x, y=y)
        return out
697*da0073e9SAndroid Build Coastguard Worker
698*da0073e9SAndroid Build Coastguard Worker
class LazyParentModule(LazyModuleMixin, torch.nn.Module):
    """Lazy base class providing ``impl``; it does not set ``_val`` itself."""

    def __init__(self) -> None:
        super().__init__()

    def impl(self, x):
        """Return ``cos(x) + self._val``."""
        cosine = x.cos()
        return cosine + self._val
705*da0073e9SAndroid Build Coastguard Worker
706*da0073e9SAndroid Build Coastguard Worker
class LazyChildModuleNoClsToBecome(LazyParentModule):
    """Lazy subclass that reaches the parent's ``impl`` through ``super()``."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        """Return the parent's ``impl`` applied to ``sin(x)``."""
        sine = x.sin()
        return super().impl(sine)

    def initialize_parameters(self, input):
        """Create ``_val`` as a 2x2 parameter of ones; ``input`` is not inspected."""
        ones = torch.ones(2, 2)
        self._val = torch.nn.Parameter(ones)
716*da0073e9SAndroid Build Coastguard Worker
717*da0073e9SAndroid Build Coastguard Worker
def requires_grad1(module: torch.nn.Module, recurse: bool = False) -> bool:
    """Return True if any parameter of ``module`` has ``requires_grad`` set.

    ``recurse`` is forwarded to ``module.parameters``.
    """
    return any(p.requires_grad for p in module.parameters(recurse))
721*da0073e9SAndroid Build Coastguard Worker
722*da0073e9SAndroid Build Coastguard Worker
def requires_grad2(module: torch.nn.Module, recurse: bool = False) -> bool:
    """Return True if any parameter of ``module`` has ``requires_grad`` set.

    ``recurse`` is forwarded to ``module.parameters``.
    """
    return any(p.requires_grad for p in module.parameters(recurse))
726*da0073e9SAndroid Build Coastguard Worker
727*da0073e9SAndroid Build Coastguard Worker
class ParametersModule1(torch.nn.Module):
    """Fixture that picks its forward path from ``requires_grad1(self)``."""

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.scale = torch.nn.Parameter(torch.randn(1, 10))

    def forward(self, x):
        """Return ``x + 1`` when a parameter needs grad, else the scaled relu path."""
        if requires_grad1(self):
            return x + 1
        activated = F.relu(self.linear1(x))
        return activated * self.scale
739*da0073e9SAndroid Build Coastguard Worker
740*da0073e9SAndroid Build Coastguard Worker
class ParametersModule2(ParametersModule1):
    """Same as ``ParametersModule1`` but branching on ``requires_grad2(self)``."""

    def forward(self, x):
        """Return ``x + 1`` when a parameter needs grad, else the scaled relu path."""
        if requires_grad2(self):
            return x + 1
        activated = F.relu(self.linear1(x))
        return activated * self.scale
747*da0073e9SAndroid Build Coastguard Worker
748*da0073e9SAndroid Build Coastguard Worker
class ParametersModule3(ParametersModule1):
    """Fixture that reads a dtype from the first entry of ``self.parameters()``."""

    def forward(self, x):
        """Return ``relu(linear1(x)) * scale`` plus a ones vector of the parameter dtype."""
        first_param = next(self.parameters())
        ones = torch.ones(10, dtype=first_param.dtype)
        scaled = F.relu(self.linear1(x)) * self.scale
        return scaled + ones
753*da0073e9SAndroid Build Coastguard Worker
754*da0073e9SAndroid Build Coastguard Worker
class ParametersModule4(ParametersModule1):
    """Fixture that reads a dtype from ``self.parameters(recurse=False)``."""

    def forward(self, x):
        """Return ``relu(linear1(x)) * scale`` plus a ones vector of the parameter dtype."""
        first_param = next(self.parameters(recurse=False))
        ones = torch.ones(10, dtype=first_param.dtype)
        scaled = F.relu(self.linear1(x)) * self.scale
        return scaled + ones
759*da0073e9SAndroid Build Coastguard Worker
760*da0073e9SAndroid Build Coastguard Worker
class ParametersModule5(torch.nn.Module):
    """Fixture holding one parameter under two names (``scale`` and ``scale_dup``)."""

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.scale = torch.nn.Parameter(torch.randn(10, 10))
        # Second attribute bound to the very same Parameter object.
        self.scale_dup = self.scale

    def forward(self, x):
        """Scale ``x`` by ``scale`` and by the number of entries in ``parameters()``."""
        num_params = 0
        for _param in self.parameters():
            num_params += 1

        return x * self.scale * num_params
774*da0073e9SAndroid Build Coastguard Worker
775*da0073e9SAndroid Build Coastguard Worker
class SuperModule(BasicModule):
    """Subclass that calls the parent forward via ``super()`` and adds 10.0."""

    def forward(self, x):
        parent_out = super().forward(x)
        return parent_out + 10.0
780*da0073e9SAndroid Build Coastguard Worker
781*da0073e9SAndroid Build Coastguard Worker
class SuperModule2(BasicModule):
    """Subclass that calls the parent forward through the class, not ``super()``."""

    def forward(self, x):
        parent_out = BasicModule.forward(self, x)
        return parent_out
785*da0073e9SAndroid Build Coastguard Worker
786*da0073e9SAndroid Build Coastguard Worker
class ComplicatedSuperParent(torch.nn.Module):
    """Parent class exposing a classmethod that doubles its argument."""

    @classmethod
    def custom_add(cls, x):
        """Return ``x + x``."""
        return x + x
792*da0073e9SAndroid Build Coastguard Worker
793*da0073e9SAndroid Build Coastguard Worker
class SuperChildCallsClassMethod(ComplicatedSuperParent):
    """Child whose classmethod reaches the parent's classmethod via ``super()``."""

    @classmethod
    def child_func(cls, x):
        """Return the parent's ``custom_add`` applied to ``x``."""
        return super().custom_add(x)

    def forward(self, x):
        """Route ``x`` through ``child_func``."""
        return self.child_func(x)
803*da0073e9SAndroid Build Coastguard Worker
804*da0073e9SAndroid Build Coastguard Worker
class HasAttrModule(torch.nn.Module):
    """Fixture that guards optional attributes with ``hasattr`` in forward."""

    def __init__(self) -> None:
        super().__init__()
        # Only ``scale`` is created here; ``scale2`` is never set in __init__.
        self.scale = torch.nn.Parameter(torch.randn(1, 10))

    def forward(self, x):
        """Apply relu, then multiply in place by whichever scales exist."""
        x = F.relu(x)
        if hasattr(self, "scale"):
            x *= self.scale
        # Taken only if something has attached ``scale2`` after construction.
        if hasattr(self, "scale2"):
            x *= self.scale2
        return x
817*da0073e9SAndroid Build Coastguard Worker
818*da0073e9SAndroid Build Coastguard Worker
class EnumValues(torch.nn.ModuleDict):
    """ModuleDict fixture that iterates ``enumerate(self.values())`` in forward."""

    def __init__(
        self,
        num_layers: int = 3,
    ) -> None:
        super().__init__()
        # Layers are named denselayer1 .. denselayer<num_layers>.
        for layer_no in range(1, num_layers + 1):
            self.add_module("denselayer%d" % layer_no, _Block())

    def forward(self, init_features):
        """Grow a feature list layer by layer, then concatenate it along dim 1."""
        collected = [init_features]
        # The enumerate index is deliberately unused.
        for _idx, dense_layer in enumerate(self.values()):
            produced = dense_layer(collected)
            collected.append(produced)
        return torch.cat(collected, 1)
834*da0073e9SAndroid Build Coastguard Worker
835*da0073e9SAndroid Build Coastguard Worker
class AccessByKeys(torch.nn.ModuleDict):
    """ModuleDict fixture that looks layers up by key while iterating ``keys()``."""

    def __init__(
        self,
        num_layers: int = 3,
    ) -> None:
        super().__init__()
        # Layers are named denselayer1 .. denselayer<num_layers>.
        for layer_no in range(1, num_layers + 1):
            self.add_module("denselayer%d" % layer_no, _Block())

    def forward(self, init_features):
        """Grow a feature list layer by layer, then concatenate it along dim 1."""
        collected = [init_features]
        for key in self.keys():
            dense_layer = self[key]
            produced = dense_layer(collected)
            collected.append(produced)
        return torch.cat(collected, 1)
851*da0073e9SAndroid Build Coastguard Worker
852*da0073e9SAndroid Build Coastguard Worker
class CallForwardDirectly(torch.nn.Module):
    """Fixture that invokes ``.forward`` on its submodules instead of calling them."""

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = torch.nn.Linear(10, 10)

    def forward(self, x):
        """Chain ``layer1.forward`` and ``layer2.forward``."""
        hidden = self.layer1.forward(x)
        return self.layer2.forward(hidden)
863*da0073e9SAndroid Build Coastguard Worker
864*da0073e9SAndroid Build Coastguard Worker
class ConvCallForwardDirectly(torch.nn.Module):
    """Fixture that invokes ``.forward`` on a bias-free ``Conv2d`` directly."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Conv2d(3, 64, 3, 1, 1, bias=False)

    def forward(self, x):
        """Return ``layer.forward(x)``."""
        out = self.layer.forward(x)
        return out
872*da0073e9SAndroid Build Coastguard Worker
873*da0073e9SAndroid Build Coastguard Worker
class ConvTransposeCallForwardDirectly(torch.nn.Module):
    """Fixture that invokes ``.forward`` on a ``ConvTranspose2d`` directly."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.ConvTranspose2d(4, 4, 4)

    def forward(self, x):
        """Return ``layer.forward(x)``."""
        out = self.layer.forward(x)
        return out
881*da0073e9SAndroid Build Coastguard Worker
882*da0073e9SAndroid Build Coastguard Worker
class ConvCallSuperForwardDirectly(torch.nn.Conv1d):
    """``Conv1d`` subclass whose forward defers to the parent via ``super()``."""

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        # The three required sizes are passed by keyword; extras pass through.
        required = {
            "in_channels": in_channels,
            "out_channels": out_channels,
            "kernel_size": kernel_size,
        }
        super().__init__(**required, **kwargs)

    def forward(self, inputs, mask=None):
        """Return the parent convolution of ``inputs``; ``mask`` is accepted but unused."""
        return super().forward(inputs)
895*da0073e9SAndroid Build Coastguard Worker
896*da0073e9SAndroid Build Coastguard Worker
class ConvTransposeCallSuperForwardDirectly(torch.nn.ConvTranspose2d):
    """``ConvTranspose2d`` subclass with a separate path for empty inputs."""

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            **kwargs,
        )

    def forward(self, x):
        """Run the parent forward for non-empty ``x``; else build an empty result.

        For an input with zero elements the output shape is computed by hand
        and handed to ``_NewEmptyTensorOp.apply`` together with ``x``.
        """
        if x.numel() > 0:
            return super().forward(x)
        # Per spatial dim: i=input size, p=padding, di=dilation, k=kernel size,
        # d=stride, op=output padding.
        output_shape = [
            ((i - 1) * d - 2 * p + (di * (k - 1) + 1) + op)
            for i, p, di, k, d, op in zip(
                x.shape[-2:],
                self.padding,
                self.dilation,
                self.kernel_size,
                self.stride,
                self.output_padding,
            )
        ]
        # Prepend batch size and output channels; the channel count is read
        # from the bias, so this line needs the layer to have a bias.
        output_shape = [x.shape[0], self.bias.shape[0]] + output_shape
        # NOTE(review): _NewEmptyTensorOp is not defined in the visible part of
        # this file (hence the F821 suppression); presumably this branch is
        # never executed by the tests -- confirm.
        return _NewEmptyTensorOp.apply(x, output_shape)  # noqa: F821
922*da0073e9SAndroid Build Coastguard Worker
923*da0073e9SAndroid Build Coastguard Worker
class ModuleNameString(torch.nn.Module):
    """Module whose ``forward`` branches on class-name strings."""

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)

    def forward(self, x):
        # Branch 1: compare this module's own class name.
        own_name = self.__class__.__name__
        if own_name == "ABC":
            return 10

        # Branch 2: compare the class name of the child module.
        child_name = self.linear1.__class__.__name__
        if child_name != "Linear":
            return 11
        return F.relu(self.linear1(x) + 10)
935*da0073e9SAndroid Build Coastguard Worker
936*da0073e9SAndroid Build Coastguard Worker
class SelfMutatingModule(torch.nn.Module):
    """Wraps ``layer`` and bumps an integer attribute on every forward call."""

    def __init__(self, layer):
        super().__init__()
        self.layer = layer
        # Plain Python int; incremented once per forward call.
        self.counter = 0

    def forward(self, x):
        # Use the counter value from *before* this call's increment.
        shifted = self.layer(x) + self.counter
        self.counter = self.counter + 1
        return F.relu(shifted)
947*da0073e9SAndroid Build Coastguard Worker
948*da0073e9SAndroid Build Coastguard Worker
class ModuleAttributePrecedenceBase(torch.nn.Module):
    """Base class that contributes a ``linear`` *method* to subclasses."""

    def linear(self, x, flag=None):
        """Return ``x * 2.0`` when ``flag`` is truthy, otherwise ``x * 3.0``."""
        factor = 2.0 if flag else 3.0
        return x * factor
954*da0073e9SAndroid Build Coastguard Worker
955*da0073e9SAndroid Build Coastguard Worker
class ModuleAttributePrecedence(ModuleAttributePrecedenceBase):
    """Fixture where instance attributes and methods share the same names.

    ``__init__`` assigns ``activation``, ``linear``, ``initializer`` and
    ``scale`` on the instance, while the class (or its base) also defines
    methods with those names. ``forward`` then reads all four through
    ``self``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Two nn.Module values ...
        self.activation = torch.nn.ReLU()
        self.linear = torch.nn.Linear(10, 10)
        # ... and two non-module values (a tensor and a float).
        self.initializer = torch.ones([10, 10])
        self.scale = 0.5

    def activation(self, x):
        return x * 1.2

    def initializer(self):
        return torch.zeros([10, 10])

    def scale(self):
        return 2.0

    def forward(self, x):
        # object attribute takes precedence unless it's a nn.Module
        shifted = self.initializer + x
        projected = self.linear(shifted)
        activated = self.activation(projected)
        return activated * self.scale
976*da0073e9SAndroid Build Coastguard Worker
977*da0073e9SAndroid Build Coastguard Worker
class ModuleForwardHasGraphBreak(torch.nn.Module):
    """Fixture whose forward gathers module introspection results, then breaks.

    The submodules cover several container kinds (plain attribute,
    Sequential, ModuleList, ModuleDict) so that the parameter names looked up
    in ``forward`` span all of them.
    """

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.layer3 = torch.nn.Sequential(BasicModule(), BasicModule())
        self.layer4 = torch.nn.ModuleList(
            [
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
            ]
        )
        self.layer5 = torch.nn.ModuleDict(
            {
                "0": torch.nn.Linear(10, 10),
            }
        )
        # Plain tensor attribute (not registered as a parameter or buffer here).
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        """
        This is used to test if the results of functions like `named_parameters`
        can be reconstructed correctly after graph break.

        https://github.com/pytorch/torchdynamo/issues/1931
        """
        x = self.layer1(x)
        # Materialize every introspection result *before* the graph break.
        params1 = dict(self.named_parameters())
        params2 = list(self.parameters())
        buffers1 = dict(self.named_buffers())
        buffers2 = list(self.buffers())
        modules1 = dict(self.named_modules())
        modules2 = list(self.modules())
        torch._dynamo.graph_break()
        # Each collection is read once after the break; only the final
        # binding (params1, the named_parameters dict) is indexed below.
        # NOTE(review): presumably the earlier reads exist to keep all six
        # locals live across the break -- confirm before simplifying.
        y = modules2
        y = modules1
        y = buffers2
        y = buffers1
        y = params2
        y = params1
        # Keys are dotted parameter paths into layer3 / layer4 / layer5.
        x = (
            self.layer2(x)
            + y["layer3.1.linear1.weight"]
            + y["layer4.2.weight"]
            + y["layer5.0.weight"]
        )
        return x * self.scale
1027*da0073e9SAndroid Build Coastguard Worker
1028*da0073e9SAndroid Build Coastguard Worker
class ModuleGuardNameIsValid(torch.nn.ModuleDict):
    # Guard names must be valid Python identifiers because eval() is used to
    # fetch the corresponding guard value. Some guard names are derived from
    # the source (module path), where special symbols are legal even though
    # they are not legal in an identifier; such patterns should be detected
    # and rewritten with getattr.
    def __init__(self) -> None:
        super().__init__()
        # Register children under names containing '@' and '-'.
        for index in (1, 2):
            self.add_module("l@yer-%d" % index, BasicModule())

    def forward(self, x):
        # Chain the children in registration order.
        out = x
        for child in self.values():
            out = child(out)
        return out
1043*da0073e9SAndroid Build Coastguard Worker
1044*da0073e9SAndroid Build Coastguard Worker
class SequentialWithDuplicatedModule(torch.nn.Module):
    # The Sequential stored in self.layer holds the *same* ReLU instance
    # (self.relu) at three positions.
    def __init__(self) -> None:
        super().__init__()
        self.relu = torch.nn.ReLU()
        shared = self.relu
        stages = [
            torch.nn.Linear(10, 20),
            shared,
            torch.nn.Linear(20, 20),
            shared,
            torch.nn.Linear(20, 10),
            shared,
        ]
        self.layer = torch.nn.Sequential(*stages)

    def forward(self, x):
        pipeline = self.layer
        return pipeline(x)
1061*da0073e9SAndroid Build Coastguard Worker
1062*da0073e9SAndroid Build Coastguard Worker
class SequentialWithDuplicatedModule2(torch.nn.Module):
    """Named-entry Sequential in which one ReLU instance appears three times."""

    def __init__(self) -> None:
        super().__init__()
        self.relu = torch.nn.ReLU()

        # Insertion order defines execution order; every "relu*" key maps to
        # the single shared self.relu object.
        named_stages = collections.OrderedDict()
        named_stages["linear1"] = torch.nn.Linear(10, 20)
        named_stages["relu1"] = self.relu
        named_stages["linear2"] = torch.nn.Linear(20, 20)
        named_stages["relu2"] = self.relu
        named_stages["linear3"] = torch.nn.Linear(20, 10)
        named_stages["relu3"] = self.relu
        self.layer = torch.nn.Sequential(named_stages)

    def forward(self, x):
        pipeline = self.layer
        return pipeline(x)
1082*da0073e9SAndroid Build Coastguard Worker
1083*da0073e9SAndroid Build Coastguard Worker
class ModuleComparison(torch.nn.Module):
    """Fixture whose forward compares submodules with ``is None`` and ``==``."""

    def __init__(self) -> None:
        super().__init__()
        self.layer0 = torch.nn.Linear(10, 10)
        self.layer1 = torch.nn.Linear(10, 10)
        self.layer2 = torch.nn.Linear(10, 10)

    @property
    def encoder_layers(self):
        """The three linear layers, in order, as a fresh list."""
        return [self.layer0, self.layer1, self.layer2]

    def forward(self, x):
        # Every layer is applied to the original x; each iteration overwrites
        # `output`, so the value from the final layer is what gets returned.
        for layer in self.encoder_layers:
            use_relu6 = layer is None or layer == self.layer0
            activation = F.relu6 if use_relu6 else F.relu
            output = activation(layer(x))
        return output
1103*da0073e9SAndroid Build Coastguard Worker
1104*da0073e9SAndroid Build Coastguard Worker
class ModulePatch1(torch.nn.Module):
    """Empty nn.Module subclass; it defines no ``forward`` of its own."""
1107*da0073e9SAndroid Build Coastguard Worker
1108*da0073e9SAndroid Build Coastguard Worker
class ModulePatch2(torch.nn.Module):
    """Module whose ``forward`` subtracts one from its input."""

    def forward(self, x):
        decremented = x - 1
        return decremented
1112*da0073e9SAndroid Build Coastguard Worker
1113*da0073e9SAndroid Build Coastguard Worker
class UnspecNonInlinableModule(torch.nn.Module):
    """Module with a data-dependent branch in ``forward``."""

    torchdynamo_force_dynamic = True  # forced to be a UnspecializedNNModule

    def forward(self, x):
        # The branch taken depends on the runtime values in x.
        if x.sum() > 0:
            return x + 1
        return x - 1
1122*da0073e9SAndroid Build Coastguard Worker
1123*da0073e9SAndroid Build Coastguard Worker
class UnspecNonInlinableToplevelModule(torch.nn.Module):
    """Top-level wrapper that only delegates to an UnspecNonInlinableModule."""

    def __init__(self) -> None:
        super().__init__()
        self.m = UnspecNonInlinableModule()

    def forward(self, x):
        inner = self.m
        return inner(x)
1131*da0073e9SAndroid Build Coastguard Worker
1132*da0073e9SAndroid Build Coastguard Worker
def make_test(fn, expected_ops=None):
    """Build a test method that runs ``standard_test`` on ``fn``.

    ``fn.eval()`` is called immediately, when the test is created (not when
    it runs). The returned function takes the test case as ``self`` and
    forwards ``fn`` with ``nargs=1`` and the given ``expected_ops``.
    """
    fn.eval()

    def test_fn(self):
        run_standard = torch._dynamo.testing.standard_test
        return run_standard(self, fn=fn, nargs=1, expected_ops=expected_ops)

    return test_fn
1141*da0073e9SAndroid Build Coastguard Worker
1142*da0073e9SAndroid Build Coastguard Worker
@contextlib.contextmanager
def temporary_tensor_subclass(torch_function=None):
    """Yield a throwaway ``torch.Tensor`` subclass registered as traceable.

    The subclass is added to
    ``torch._dynamo.config.traceable_tensor_subclasses`` for the duration of
    the ``with`` block and removed again on exit, even if the block raises.
    When ``torch_function`` is given, it is invoked (with no arguments) on
    every ``__torch_function__`` dispatch before deferring to the default
    implementation.
    """

    class TensorProxy(torch.Tensor):
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            if torch_function is not None:
                torch_function()
            return super().__torch_function__(func, types, args, kwargs)

    dynamo_config = torch._dynamo.config
    dynamo_config.traceable_tensor_subclasses.add(TensorProxy)
    try:
        yield TensorProxy
    finally:
        dynamo_config.traceable_tensor_subclasses.remove(TensorProxy)
1157*da0073e9SAndroid Build Coastguard Worker
1158*da0073e9SAndroid Build Coastguard Worker
class NNModuleTests(torch._dynamo.test_case.TestCase):
    # Table of generated tests: each attribute is the function returned by
    # make_test() for a freshly constructed fixture module, so the fixture is
    # instantiated (and put in eval mode by make_test) at class-definition
    # time. Several fixtures are registered twice under "...1"/"...2" names,
    # each with its own instance.
    test_seq = make_test(Seq())
    test_basicmodule1 = make_test(BasicModule())
    test_basicmodule2 = make_test(BasicModule())
    test_submodules1 = make_test(SubmoduleExample())
    test_submodules2 = make_test(SubmoduleExample())
    test_modulemethod1 = make_test(ModuleMethodCall())
    test_modulemethod2 = make_test(ModuleMethodCall())
    test_module_call_module_with_static_forward = make_test(
        ModuleCallModuleWithStaticForward()
    )
    test_module_static_method = make_test(ModuleStaticMethodCall())
    test_fnmember = make_test(FnMember())
    test_fnmembercmp1 = make_test(FnMemberCmp(F.relu))
    test_fnmembercmp2 = make_test(FnMemberCmp(None))
    test_constloop = make_test(ConstLoop())
    test_istraining1 = make_test(IsTrainingCheck())
    test_istraining2 = make_test(IsTrainingCheck())
    test_iseval1 = make_test(IsEvalCheck())
    test_iseval2 = make_test(IsEvalCheck())
    test_viamodulecall = make_test(ViaModuleCall())
    test_isnonelayer = make_test(IsNoneLayer())
    test_layerlist = make_test(LayerList())
    test_tensorlist = make_test(TensorList())
    test_intarg = make_test(IntArg())
    test_cfgmod = make_test(CfgModule())
    test_stringmember = make_test(StringMember())
    test_modulelist = make_test(ModuleList())
    test_modulelist_nested = make_test(NestedModuleList())
    test_modulelist_custom = make_test(CustomGetItemModuleList())
    test_moduledict = make_test(ModuleDict())
    test_moduledict_custom = make_test(CustomGetItemModuleDict())
    test_parameterdict = make_test(ParameterDict())
    test_parameterdict_custom = make_test(CustomGetItemParameterDict())
    test_super1 = make_test(SuperModule())
    test_super2 = make_test(SuperModule2())
    test_super_class_method = make_test(SuperChildCallsClassMethod())
    test_children = make_test(Children())
    test_named_children = make_test(NamedChildren())
    test_densenet = make_test(DenseNetBlocks())
    test_parameters1 = make_test(ParametersModule1())
    test_parameters2 = make_test(ParametersModule2())
    # The only entry that pins an explicit op count.
    test_parameters3 = make_test(ParametersModule3(), expected_ops=5)
    test_parameters4 = make_test(ParametersModule4())
    test_parameters5 = make_test(ParametersModule5())
    test_hasattr = make_test(HasAttrModule())
    test_enumvalues = make_test(EnumValues())
    test_access_by_keys = make_test(AccessByKeys())
    test_module_class_method = make_test(ModuleClassMethodCall())
    test_module_property = make_test(ModuleProperty())
    test_forward_directly = make_test(CallForwardDirectly())
    test_module_name_string = make_test(ModuleNameString())
    test_module_attribute_precedence = make_test(ModuleAttributePrecedence())
    test_module_guard_name_is_valid = make_test(ModuleGuardNameIsValid())
    test_sequential_with_duplicated_module = make_test(SequentialWithDuplicatedModule())
    test_sequential_with_duplicated_module2 = make_test(
        SequentialWithDuplicatedModule2()
    )
    test_module_comparison = make_test(ModuleComparison())
1218*da0073e9SAndroid Build Coastguard Worker
1219*da0073e9SAndroid Build Coastguard Worker    def test_module_forward_has_graph_break(self):
1220*da0073e9SAndroid Build Coastguard Worker        m = ModuleForwardHasGraphBreak()
1221*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([10, 10])
1222*da0073e9SAndroid Build Coastguard Worker        ref = m(x)
1223*da0073e9SAndroid Build Coastguard Worker        opt_m = torch._dynamo.optimize("eager")(m)
1224*da0073e9SAndroid Build Coastguard Worker        res = opt_m(x)
1225*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ref, res))
1226*da0073e9SAndroid Build Coastguard Worker
1227*da0073e9SAndroid Build Coastguard Worker    def test_unsupportedmethod(self):
1228*da0073e9SAndroid Build Coastguard Worker        m = UnsupportedMethodCall()
1229*da0073e9SAndroid Build Coastguard Worker        i = torch.randn(10)
1230*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
1231*da0073e9SAndroid Build Coastguard Worker        opt_m = torch._dynamo.optimize(cnt)(m)
1232*da0073e9SAndroid Build Coastguard Worker        r = opt_m(i)
1233*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._dynamo.testing.same(r, m(i)))
1234*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.op_count, 5)
1235*da0073e9SAndroid Build Coastguard Worker
1236*da0073e9SAndroid Build Coastguard Worker    def test_unsupportedmodule(self):
1237*da0073e9SAndroid Build Coastguard Worker        m = UnsupportedModuleCall()
1238*da0073e9SAndroid Build Coastguard Worker        i = torch.randn(10)
1239*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
1240*da0073e9SAndroid Build Coastguard Worker        opt_m = torch._dynamo.optimize(cnt)(m)
1241*da0073e9SAndroid Build Coastguard Worker        r = opt_m(i)
1242*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._dynamo.testing.same(r, m(i)))
1243*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.op_count, 6)
1244*da0073e9SAndroid Build Coastguard Worker
1245*da0073e9SAndroid Build Coastguard Worker    def test_self_mutating1(self):
1246*da0073e9SAndroid Build Coastguard Worker        m1 = torch.nn.Linear(10, 10)
1247*da0073e9SAndroid Build Coastguard Worker        m2 = SelfMutatingModule(m1)
1248*da0073e9SAndroid Build Coastguard Worker        m3 = SelfMutatingModule(m1)
1249*da0073e9SAndroid Build Coastguard Worker        m4 = SelfMutatingModule(m1)
1250*da0073e9SAndroid Build Coastguard Worker        i = torch.randn(10)
1251*da0073e9SAndroid Build Coastguard Worker        out2 = [m2(i), m2(i), m2(i)]
1252*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
1253*da0073e9SAndroid Build Coastguard Worker        opt_m3 = torch._dynamo.optimize_assert(cnt)(m3)
1254*da0073e9SAndroid Build Coastguard Worker        opt_m4 = torch._dynamo.optimize_assert(cnt)(m4)
1255*da0073e9SAndroid Build Coastguard Worker        out3 = [opt_m3(i), opt_m3(i), opt_m3(i)]
1256*da0073e9SAndroid Build Coastguard Worker        out4 = [opt_m4(i), opt_m4(i), opt_m4(i)]
1257*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._dynamo.testing.same(out2, out3))
1258*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._dynamo.testing.same(out2, out4))
1259*da0073e9SAndroid Build Coastguard Worker        if torch._dynamo.config.assume_static_by_default:
1260*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(cnt.frame_count, """2""")
1261*da0073e9SAndroid Build Coastguard Worker        else:
1262*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(cnt.frame_count, """1""")
1263*da0073e9SAndroid Build Coastguard Worker
1264*da0073e9SAndroid Build Coastguard Worker    @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
1265*da0073e9SAndroid Build Coastguard Worker    def test_generation_tag(self):
1266*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
1267*da0073e9SAndroid Build Coastguard Worker
1268*da0073e9SAndroid Build Coastguard Worker        # guarantee that we have installed
1269*da0073e9SAndroid Build Coastguard Worker        # the generation tagging function
1270*da0073e9SAndroid Build Coastguard Worker        with torch._dynamo.optimize_assert(cnt):
1271*da0073e9SAndroid Build Coastguard Worker            pass
1272*da0073e9SAndroid Build Coastguard Worker
1273*da0073e9SAndroid Build Coastguard Worker        m1 = torch.nn.Linear(10, 10)
1274*da0073e9SAndroid Build Coastguard Worker        prev_generation = GenerationTracker.get_generation_value(m1)
1275*da0073e9SAndroid Build Coastguard Worker        cur_generation = prev_generation + 1
1276*da0073e9SAndroid Build Coastguard Worker
1277*da0073e9SAndroid Build Coastguard Worker        with torch._dynamo.optimize_assert(cnt):
1278*da0073e9SAndroid Build Coastguard Worker            m2 = torch.nn.Linear(10, 10)
1279*da0073e9SAndroid Build Coastguard Worker
1280*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(GenerationTracker.get_generation_value(m1), prev_generation)
1281*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(GenerationTracker.get_generation_value(m2), cur_generation)
1282*da0073e9SAndroid Build Coastguard Worker        # check that newly constructed instances
1283*da0073e9SAndroid Build Coastguard Worker        # also have the same generation (even if copied from an old instance)
1284*da0073e9SAndroid Build Coastguard Worker        m3 = deepcopy(m1)
1285*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(GenerationTracker.get_generation_value(m3), cur_generation)
1286*da0073e9SAndroid Build Coastguard Worker
1287*da0073e9SAndroid Build Coastguard Worker    def test_simple_torch_function(self):
1288*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1289*da0073e9SAndroid Build Coastguard Worker            # function call, twice to test wrapping
1290*da0073e9SAndroid Build Coastguard Worker            x = F.sigmoid(x)
1291*da0073e9SAndroid Build Coastguard Worker            x = F.sigmoid(x)
1292*da0073e9SAndroid Build Coastguard Worker            # method call, twice to test wrapping
1293*da0073e9SAndroid Build Coastguard Worker            x = x.sigmoid()
1294*da0073e9SAndroid Build Coastguard Worker            x = x.sigmoid()
1295*da0073e9SAndroid Build Coastguard Worker            return x
1296*da0073e9SAndroid Build Coastguard Worker
1297*da0073e9SAndroid Build Coastguard Worker        with temporary_tensor_subclass() as TensorProxy:
1298*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(1).as_subclass(TensorProxy)
1299*da0073e9SAndroid Build Coastguard Worker            cnt = torch._dynamo.testing.CompileCounter()
1300*da0073e9SAndroid Build Coastguard Worker            out1 = foo(x)
1301*da0073e9SAndroid Build Coastguard Worker            opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo)
1302*da0073e9SAndroid Build Coastguard Worker            out2 = opt_foo(x)
1303*da0073e9SAndroid Build Coastguard Worker
1304*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnt.op_count, 4)
1305*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch._dynamo.testing.same(out1, out2))
1306*da0073e9SAndroid Build Coastguard Worker
    def test_torch_function_with_closure(self):
        """Compile ``foo`` while a closure-reading hook is registered.

        Everything lives inside ``run`` so that ``counter`` is a closure cell
        read by ``function``, which is passed to ``temporary_tensor_subclass``
        as the ``__torch_function__`` hook.
        """

        def run():
            def foo(x):
                # function call, twice to test wrapping
                x = F.sigmoid(x)
                x = F.sigmoid(x)
                # method call, twice to test wrapping
                x = x.sigmoid()
                x = x.sigmoid()
                return x

            counter = 0

            def function():
                nonlocal counter
                # for now, only support reads from closure cells
                # TODO(future PR): support writes as well
                counter + 1

            with temporary_tensor_subclass(function) as TensorProxy:
                x = torch.randn(1).as_subclass(TensorProxy)
                # NOTE(review): the next line rebinds x to a plain tensor, so
                # the TensorProxy instance created above is never passed to
                # foo and the hook is not exercised by these calls. Looks
                # unintentional -- confirm whether the overwrite should be
                # removed before relying on this test for subclass coverage.
                x = torch.randn(1)
                cnt = torch._dynamo.testing.CompileCounter()
                out1 = foo(x)
                opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo)
                out2 = opt_foo(x)

                self.assertEqual(cnt.op_count, 4)
                self.assertTrue(torch._dynamo.testing.same(out1, out2))

        run()
1338*da0073e9SAndroid Build Coastguard Worker
1339*da0073e9SAndroid Build Coastguard Worker    def test_torch_mangled_class_name(self):
1340*da0073e9SAndroid Build Coastguard Worker        original = TensorWithTFOverrideVariable.global_mangled_class_name
1341*da0073e9SAndroid Build Coastguard Worker        results = []
1342*da0073e9SAndroid Build Coastguard Worker
1343*da0073e9SAndroid Build Coastguard Worker        def instrumented(self, tx):
1344*da0073e9SAndroid Build Coastguard Worker            result = original(self, tx)
1345*da0073e9SAndroid Build Coastguard Worker            results.append(result)
1346*da0073e9SAndroid Build Coastguard Worker            return result
1347*da0073e9SAndroid Build Coastguard Worker
1348*da0073e9SAndroid Build Coastguard Worker        TensorWithTFOverrideVariable.global_mangled_class_name = instrumented
1349*da0073e9SAndroid Build Coastguard Worker
1350*da0073e9SAndroid Build Coastguard Worker        def one_break(x):
1351*da0073e9SAndroid Build Coastguard Worker            x = F.sigmoid(x)
1352*da0073e9SAndroid Build Coastguard Worker            print()  # force break
1353*da0073e9SAndroid Build Coastguard Worker            x = x.sigmoid()
1354*da0073e9SAndroid Build Coastguard Worker            return x
1355*da0073e9SAndroid Build Coastguard Worker
1356*da0073e9SAndroid Build Coastguard Worker        try:
1357*da0073e9SAndroid Build Coastguard Worker            with temporary_tensor_subclass() as TensorProxy:
1358*da0073e9SAndroid Build Coastguard Worker                x = torch.randn(1).as_subclass(TensorProxy)
1359*da0073e9SAndroid Build Coastguard Worker                x1 = one_break(x)
1360*da0073e9SAndroid Build Coastguard Worker
1361*da0073e9SAndroid Build Coastguard Worker                cnt = torch._dynamo.testing.CompileCounter()
1362*da0073e9SAndroid Build Coastguard Worker                opt_one_break = torch._dynamo.optimize(cnt)(one_break)
1363*da0073e9SAndroid Build Coastguard Worker                x2 = opt_one_break(x)
1364*da0073e9SAndroid Build Coastguard Worker
1365*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch._dynamo.testing.same(x1, x2))
1366*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(cnt.frame_count, 2)
1367*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(cnt.op_count, 2)
1368*da0073e9SAndroid Build Coastguard Worker
1369*da0073e9SAndroid Build Coastguard Worker                compile_ids = set()
1370*da0073e9SAndroid Build Coastguard Worker                for r in results:
1371*da0073e9SAndroid Build Coastguard Worker                    # A mangled classname looks like __subclass_TensorProxy_94524181138240_c0
1372*da0073e9SAndroid Build Coastguard Worker                    # where the last segment contains the compile_id.
1373*da0073e9SAndroid Build Coastguard Worker                    prefix = "__subclass_TensorProxy_"
1374*da0073e9SAndroid Build Coastguard Worker                    before, sep, after = r.partition(prefix)
1375*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(before, "")
1376*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(sep, prefix)
1377*da0073e9SAndroid Build Coastguard Worker
1378*da0073e9SAndroid Build Coastguard Worker                    class_type_id, compile_id = after.split("_")
1379*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(class_type_id.isnumeric())
1380*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(compile_id.startswith("c"))
1381*da0073e9SAndroid Build Coastguard Worker
1382*da0073e9SAndroid Build Coastguard Worker                    cid = compile_id[1:]
1383*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(cid.isnumeric())
1384*da0073e9SAndroid Build Coastguard Worker                    compile_ids.add(cid)
1385*da0073e9SAndroid Build Coastguard Worker
1386*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(len(compile_ids), 3)
1387*da0073e9SAndroid Build Coastguard Worker
1388*da0073e9SAndroid Build Coastguard Worker        finally:
1389*da0073e9SAndroid Build Coastguard Worker            TensorWithTFOverrideVariable.global_mangled_class_name = original
1390*da0073e9SAndroid Build Coastguard Worker
1391*da0073e9SAndroid Build Coastguard Worker    @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
1392*da0073e9SAndroid Build Coastguard Worker    def test_nn_moduledict_contains(self):
1393*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
1394*da0073e9SAndroid Build Coastguard Worker            def __init__(self, module_dict):
1395*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1396*da0073e9SAndroid Build Coastguard Worker                self.module_dict = module_dict
1397*da0073e9SAndroid Build Coastguard Worker
1398*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1399*da0073e9SAndroid Build Coastguard Worker                if "foo" in self.module_dict:
1400*da0073e9SAndroid Build Coastguard Worker                    x = torch.mul(x, 1.0)
1401*da0073e9SAndroid Build Coastguard Worker                x = torch.add(x, 1.0)
1402*da0073e9SAndroid Build Coastguard Worker                return x
1403*da0073e9SAndroid Build Coastguard Worker
1404*da0073e9SAndroid Build Coastguard Worker        module_dict = torch.nn.ModuleDict({"foo": torch.nn.Conv2d(1, 1, 1)})
1405*da0073e9SAndroid Build Coastguard Worker        m = M(module_dict)
1406*da0073e9SAndroid Build Coastguard Worker        data = torch.randn(1)
1407*da0073e9SAndroid Build Coastguard Worker        out1 = m(data)
1408*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
1409*da0073e9SAndroid Build Coastguard Worker        opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
1410*da0073e9SAndroid Build Coastguard Worker        out2 = opt_m(data)
1411*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.op_count, 2)
1412*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._dynamo.testing.same(out1, out2))
1413*da0073e9SAndroid Build Coastguard Worker
1414*da0073e9SAndroid Build Coastguard Worker        module_dict = torch.nn.ModuleDict({"bar": torch.nn.Conv2d(1, 1, 1)})
1415*da0073e9SAndroid Build Coastguard Worker        m = M(module_dict)
1416*da0073e9SAndroid Build Coastguard Worker        data = torch.randn(1)
1417*da0073e9SAndroid Build Coastguard Worker        out1 = m(data)
1418*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
1419*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
1420*da0073e9SAndroid Build Coastguard Worker        opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
1421*da0073e9SAndroid Build Coastguard Worker        out2 = opt_m(data)
1422*da0073e9SAndroid Build Coastguard Worker
1423*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.op_count, 1)
1424*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._dynamo.testing.same(out1, out2))
1425*da0073e9SAndroid Build Coastguard Worker
1426*da0073e9SAndroid Build Coastguard Worker        module_dict = torch.nn.ModuleDict({"cat": torch.nn.Conv2d(1, 1, 1)})
1427*da0073e9SAndroid Build Coastguard Worker        pre = m(data)
1428*da0073e9SAndroid Build Coastguard Worker        cnt.clear()
1429*da0073e9SAndroid Build Coastguard Worker
1430*da0073e9SAndroid Build Coastguard Worker        with torch._dynamo.optimize(cnt, nopython=False):
1431*da0073e9SAndroid Build Coastguard Worker            opt_pre = m(data)
1432*da0073e9SAndroid Build Coastguard Worker            m = M(module_dict)
1433*da0073e9SAndroid Build Coastguard Worker            data = torch.randn(1)
1434*da0073e9SAndroid Build Coastguard Worker            out1 = m(data)
1435*da0073e9SAndroid Build Coastguard Worker
1436*da0073e9SAndroid Build Coastguard Worker        out_post = m(data)
1437*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
1438*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.op_count, 1)
1439*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._dynamo.testing.same(pre, opt_pre))
1440*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._dynamo.testing.same(out1, out_post))
1441*da0073e9SAndroid Build Coastguard Worker
1442*da0073e9SAndroid Build Coastguard Worker    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
1443*da0073e9SAndroid Build Coastguard Worker    @expectedFailureDynamic
1444*da0073e9SAndroid Build Coastguard Worker    def test_lazy_module1(self):
1445*da0073e9SAndroid Build Coastguard Worker        input_shape = (16, 3, 6, 7, 8)
1446*da0073e9SAndroid Build Coastguard Worker
1447*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
1448*da0073e9SAndroid Build Coastguard Worker        module = LazyModule()
1449*da0073e9SAndroid Build Coastguard Worker
1450*da0073e9SAndroid Build Coastguard Worker        def test_static_module():
1451*da0073e9SAndroid Build Coastguard Worker            input = torch.ones(*input_shape)
1452*da0073e9SAndroid Build Coastguard Worker            module(input)
1453*da0073e9SAndroid Build Coastguard Worker
1454*da0073e9SAndroid Build Coastguard Worker        # test no graph break
1455*da0073e9SAndroid Build Coastguard Worker        opt_test_static_module = torch._dynamo.optimize(cnt, nopython=True)(
1456*da0073e9SAndroid Build Coastguard Worker            test_static_module
1457*da0073e9SAndroid Build Coastguard Worker        )
1458*da0073e9SAndroid Build Coastguard Worker        opt_test_static_module()
1459*da0073e9SAndroid Build Coastguard Worker
1460*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
1461*da0073e9SAndroid Build Coastguard Worker            isinstance(module, MaterializedModule),
1462*da0073e9SAndroid Build Coastguard Worker            "Module should be transformed to an instance of MaterializedModule.",
1463*da0073e9SAndroid Build Coastguard Worker        )
1464*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module.param.shape, input_shape)
1465*da0073e9SAndroid Build Coastguard Worker
1466*da0073e9SAndroid Build Coastguard Worker        # test when mapped to UnspecializedNNModule
1467*da0073e9SAndroid Build Coastguard Worker        module = LazyModule()
1468*da0073e9SAndroid Build Coastguard Worker
1469*da0073e9SAndroid Build Coastguard Worker        def test_unspecialized():
1470*da0073e9SAndroid Build Coastguard Worker            nonlocal module
1471*da0073e9SAndroid Build Coastguard Worker            module = LazyModule()
1472*da0073e9SAndroid Build Coastguard Worker            input = torch.ones(*input_shape)
1473*da0073e9SAndroid Build Coastguard Worker            module(input)
1474*da0073e9SAndroid Build Coastguard Worker
1475*da0073e9SAndroid Build Coastguard Worker        opt_test_unspecialized = torch._dynamo.optimize(cnt)(test_unspecialized)
1476*da0073e9SAndroid Build Coastguard Worker        opt_test_unspecialized()
1477*da0073e9SAndroid Build Coastguard Worker
1478*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
1479*da0073e9SAndroid Build Coastguard Worker            isinstance(module, MaterializedModule),
1480*da0073e9SAndroid Build Coastguard Worker            "Module should be transformed to an instance of MaterializedModule.",
1481*da0073e9SAndroid Build Coastguard Worker        )
1482*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module.param.shape, input_shape)
1483*da0073e9SAndroid Build Coastguard Worker
1484*da0073e9SAndroid Build Coastguard Worker        # test with a static module in torch.*
1485*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.modules.LazyBatchNorm3d(
1486*da0073e9SAndroid Build Coastguard Worker            affine=False, track_running_stats=False
1487*da0073e9SAndroid Build Coastguard Worker        )
1488*da0073e9SAndroid Build Coastguard Worker
1489*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
1490*da0073e9SAndroid Build Coastguard Worker
1491*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
1492*da0073e9SAndroid Build Coastguard Worker
1493*da0073e9SAndroid Build Coastguard Worker        def test_torch_static():
1494*da0073e9SAndroid Build Coastguard Worker            input = torch.ones(*input_shape)
1495*da0073e9SAndroid Build Coastguard Worker            return module(input)  # fully materialized
1496*da0073e9SAndroid Build Coastguard Worker
1497*da0073e9SAndroid Build Coastguard Worker        # test no graph break
1498*da0073e9SAndroid Build Coastguard Worker        opt_test_torch_static = torch._dynamo.optimize(cnt, nopython=True)(
1499*da0073e9SAndroid Build Coastguard Worker            test_torch_static
1500*da0073e9SAndroid Build Coastguard Worker        )
1501*da0073e9SAndroid Build Coastguard Worker        opt_test_torch_static()
1502*da0073e9SAndroid Build Coastguard Worker        out = opt_test_torch_static()
1503*da0073e9SAndroid Build Coastguard Worker
1504*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(out, module(torch.ones(*input_shape))))
1505*da0073e9SAndroid Build Coastguard Worker
1506*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
1507*da0073e9SAndroid Build Coastguard Worker            isinstance(module, torch.nn.modules.batchnorm.BatchNorm3d),
1508*da0073e9SAndroid Build Coastguard Worker            "Module should be transformed to an instance of BatchNorm3d.",
1509*da0073e9SAndroid Build Coastguard Worker        )
1510*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1, "No guards should have triggered.")
1511*da0073e9SAndroid Build Coastguard Worker
1512*da0073e9SAndroid Build Coastguard Worker    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
1513*da0073e9SAndroid Build Coastguard Worker    @expectedFailureDynamic
1514*da0073e9SAndroid Build Coastguard Worker    def test_lazy_module2(self):
1515*da0073e9SAndroid Build Coastguard Worker        # Test FX graph 'call_module' works well if argument is lazy module
1516*da0073e9SAndroid Build Coastguard Worker        m = LazyMLP()
1517*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([10, 10])
1518*da0073e9SAndroid Build Coastguard Worker        opt_m = torch._dynamo.optimize("eager", nopython=True)(m)
1519*da0073e9SAndroid Build Coastguard Worker        # We should run compile mode firstly, otherwise the module
1520*da0073e9SAndroid Build Coastguard Worker        # would be initialized when running eager mode.
1521*da0073e9SAndroid Build Coastguard Worker        res = opt_m(x)
1522*da0073e9SAndroid Build Coastguard Worker        ref = m(x)
1523*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ref, res))
1524*da0073e9SAndroid Build Coastguard Worker
1525*da0073e9SAndroid Build Coastguard Worker    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
1526*da0073e9SAndroid Build Coastguard Worker    @expectedFailureDynamic
1527*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
1528*da0073e9SAndroid Build Coastguard Worker    def test_lazy_module3(self):
1529*da0073e9SAndroid Build Coastguard Worker        m = LazyMLP()
1530*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([10, 10])
1531*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
1532*da0073e9SAndroid Build Coastguard Worker        opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
1533*da0073e9SAndroid Build Coastguard Worker        # first iteration
1534*da0073e9SAndroid Build Coastguard Worker        res = opt_m(x)
1535*da0073e9SAndroid Build Coastguard Worker        ref = m(x)
1536*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ref, res))
1537*da0073e9SAndroid Build Coastguard Worker        # move to cuda and second iteration
1538*da0073e9SAndroid Build Coastguard Worker        m = m.to("cuda")
1539*da0073e9SAndroid Build Coastguard Worker        x = x.to("cuda")
1540*da0073e9SAndroid Build Coastguard Worker        res = opt_m(x)
1541*da0073e9SAndroid Build Coastguard Worker        ref = m(x)
1542*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ref, res))
1543*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 2)
1544*da0073e9SAndroid Build Coastguard Worker
1545*da0073e9SAndroid Build Coastguard Worker    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
1546*da0073e9SAndroid Build Coastguard Worker    @expectedFailureDynamic
1547*da0073e9SAndroid Build Coastguard Worker    def test_lazy_module4(self):
1548*da0073e9SAndroid Build Coastguard Worker        m = LazyMLP()
1549*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([10, 10])
1550*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
1551*da0073e9SAndroid Build Coastguard Worker        opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
1552*da0073e9SAndroid Build Coastguard Worker        # first iteration
1553*da0073e9SAndroid Build Coastguard Worker        res = opt_m(x)
1554*da0073e9SAndroid Build Coastguard Worker        ref = m(x)
1555*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ref, res))
1556*da0073e9SAndroid Build Coastguard Worker        # input shape changed and second iteration
1557*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([20, 20])
1558*da0073e9SAndroid Build Coastguard Worker        try:
1559*da0073e9SAndroid Build Coastguard Worker            opt_m(x)
1560*da0073e9SAndroid Build Coastguard Worker        except RuntimeError:
1561*da0073e9SAndroid Build Coastguard Worker            self.assertIn("must have same reduction dim", traceback.format_exc())
1562*da0073e9SAndroid Build Coastguard Worker
1563*da0073e9SAndroid Build Coastguard Worker    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
1564*da0073e9SAndroid Build Coastguard Worker    @expectedFailureDynamic
1565*da0073e9SAndroid Build Coastguard Worker    def test_lazy_module5(self):
1566*da0073e9SAndroid Build Coastguard Worker        # Test lazy module works well with list/tuple input
1567*da0073e9SAndroid Build Coastguard Worker        m = LazyModuleWithListInput()
1568*da0073e9SAndroid Build Coastguard Worker        x = [torch.rand([5, 5])] * 3 + [None]
1569*da0073e9SAndroid Build Coastguard Worker        opt_m = torch._dynamo.optimize("eager", nopython=True)(m)
1570*da0073e9SAndroid Build Coastguard Worker        res = opt_m(x)
1571*da0073e9SAndroid Build Coastguard Worker        ref = m(x)
1572*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ref, res))
1573*da0073e9SAndroid Build Coastguard Worker
1574*da0073e9SAndroid Build Coastguard Worker    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
1575*da0073e9SAndroid Build Coastguard Worker    @expectedFailureDynamic
1576*da0073e9SAndroid Build Coastguard Worker    def test_lazy_module6(self):
1577*da0073e9SAndroid Build Coastguard Worker        # Test new lazy submodule in lazy module's initialize_parameters
1578*da0073e9SAndroid Build Coastguard Worker        m = LazyModuleWithLazySubmodule()
1579*da0073e9SAndroid Build Coastguard Worker        x = [torch.rand([5, 5])] * 3
1580*da0073e9SAndroid Build Coastguard Worker        opt_m = torch._dynamo.optimize("eager", nopython=True)(m)
1581*da0073e9SAndroid Build Coastguard Worker        res = opt_m(x)
1582*da0073e9SAndroid Build Coastguard Worker        ref = m(x)
1583*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ref, res))
1584*da0073e9SAndroid Build Coastguard Worker
1585*da0073e9SAndroid Build Coastguard Worker    # RuntimeError: SymIntArrayRef expected to contain only concrete integers
1586*da0073e9SAndroid Build Coastguard Worker    @expectedFailureDynamic
1587*da0073e9SAndroid Build Coastguard Worker    def test_lazy_module7(self):
1588*da0073e9SAndroid Build Coastguard Worker        # Test lazy module works well with namedtuple/dict input
1589*da0073e9SAndroid Build Coastguard Worker        m = LazyModuleWithNamedTupleInput()
1590*da0073e9SAndroid Build Coastguard Worker        x = MyInput(
1591*da0073e9SAndroid Build Coastguard Worker            x={"a": [torch.rand([5, 5])] * 3, "b": torch.rand([5, 5])},
1592*da0073e9SAndroid Build Coastguard Worker            y=torch.rand([5, 5]),
1593*da0073e9SAndroid Build Coastguard Worker        )
1594*da0073e9SAndroid Build Coastguard Worker        opt_m = torch.compile(backend="eager", fullgraph=True)(m)
1595*da0073e9SAndroid Build Coastguard Worker        res = opt_m(x)
1596*da0073e9SAndroid Build Coastguard Worker        ref = m(x)
1597*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ref, res))
1598*da0073e9SAndroid Build Coastguard Worker
1599*da0073e9SAndroid Build Coastguard Worker    def test_lazy_module_no_cls_to_become(self):
1600*da0073e9SAndroid Build Coastguard Worker        # make sure super() works in the case where cls_to_become is None
1601*da0073e9SAndroid Build Coastguard Worker        m = LazyChildModuleNoClsToBecome()
1602*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, 2)
1603*da0073e9SAndroid Build Coastguard Worker        opt_m = torch._dynamo.optimize("eager", nopython=True)(m)
1604*da0073e9SAndroid Build Coastguard Worker        res = opt_m(x)
1605*da0073e9SAndroid Build Coastguard Worker        ref = m(x)
1606*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ref, res))
1607*da0073e9SAndroid Build Coastguard Worker
1608*da0073e9SAndroid Build Coastguard Worker    def test_lazy_module_kwargs(self):
1609*da0073e9SAndroid Build Coastguard Worker        m = LazyModuleKwArgs()
1610*da0073e9SAndroid Build Coastguard Worker        x = [torch.rand([5, 5])] * 3
1611*da0073e9SAndroid Build Coastguard Worker        y = [torch.rand([5, 5])] * 2
1612*da0073e9SAndroid Build Coastguard Worker        opt_m = torch.compile(backend="eager", fullgraph=True)(m)
1613*da0073e9SAndroid Build Coastguard Worker        exp_res = m(x, y)
1614*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(exp_res, opt_m(x, y)))
1615*da0073e9SAndroid Build Coastguard Worker
1616*da0073e9SAndroid Build Coastguard Worker    def test_call_fn_with_non_const_inputs_safe(self):
1617*da0073e9SAndroid Build Coastguard Worker        class ModuleSpecialFwd(torch.nn.Module):
1618*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1619*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1620*da0073e9SAndroid Build Coastguard Worker                self.conv = torch.nn.Conv2d(
1621*da0073e9SAndroid Build Coastguard Worker                    in_channels=3, out_channels=20, kernel_size=(5, 5)
1622*da0073e9SAndroid Build Coastguard Worker                )
1623*da0073e9SAndroid Build Coastguard Worker
1624*da0073e9SAndroid Build Coastguard Worker            def _conv_forward(self, x):
1625*da0073e9SAndroid Build Coastguard Worker                return self.conv._conv_forward(x, self.conv.weight, self.conv.bias)
1626*da0073e9SAndroid Build Coastguard Worker
1627*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1628*da0073e9SAndroid Build Coastguard Worker                return self._conv_forward(x)
1629*da0073e9SAndroid Build Coastguard Worker
1630*da0073e9SAndroid Build Coastguard Worker        mod = ModuleSpecialFwd()
1631*da0073e9SAndroid Build Coastguard Worker        rx = torch.randn([3, 10, 10])
1632*da0073e9SAndroid Build Coastguard Worker        real = mod(rx)
1633*da0073e9SAndroid Build Coastguard Worker        graph, _ = torch._dynamo.export(mod)(rx)
1634*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._dynamo.testing.same(real, graph(rx)))
1635*da0073e9SAndroid Build Coastguard Worker
1636*da0073e9SAndroid Build Coastguard Worker    def test_conv_call_forward_directly(self):
1637*da0073e9SAndroid Build Coastguard Worker        m = ConvCallForwardDirectly()
1638*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([4, 3, 9, 9])
1639*da0073e9SAndroid Build Coastguard Worker        ref = m(x)
1640*da0073e9SAndroid Build Coastguard Worker        opt_m = torch.compile(backend="eager", fullgraph=True)(m)
1641*da0073e9SAndroid Build Coastguard Worker        res = opt_m(x)
1642*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ref, res))
1643*da0073e9SAndroid Build Coastguard Worker
1644*da0073e9SAndroid Build Coastguard Worker    def test_conv_transpose_call_forward_directly(self):
1645*da0073e9SAndroid Build Coastguard Worker        m = ConvTransposeCallForwardDirectly()
1646*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([4, 4, 4, 4])
1647*da0073e9SAndroid Build Coastguard Worker        ref = m(x)
1648*da0073e9SAndroid Build Coastguard Worker        opt_m = torch.compile(backend="eager", fullgraph=True)(m)
1649*da0073e9SAndroid Build Coastguard Worker        res = opt_m(x)
1650*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ref, res))
1651*da0073e9SAndroid Build Coastguard Worker
1652*da0073e9SAndroid Build Coastguard Worker    def test_conv_call_super_forward_directly(self):
1653*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 4)
1654*da0073e9SAndroid Build Coastguard Worker        m = ConvCallSuperForwardDirectly(4, 4, 4)
1655*da0073e9SAndroid Build Coastguard Worker        ref = m(x)
1656*da0073e9SAndroid Build Coastguard Worker        opt_m = torch.compile(backend="eager", fullgraph=True)(m)
1657*da0073e9SAndroid Build Coastguard Worker        res = opt_m(x)
1658*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ref, res))
1659*da0073e9SAndroid Build Coastguard Worker
1660*da0073e9SAndroid Build Coastguard Worker    def test_conv_transpose_call_super_forward_directly(self):
1661*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 4, 4)
1662*da0073e9SAndroid Build Coastguard Worker        m = ConvTransposeCallSuperForwardDirectly(4, 4, 4)
1663*da0073e9SAndroid Build Coastguard Worker        ref = m(x)
1664*da0073e9SAndroid Build Coastguard Worker        opt_m = torch.compile(backend="eager", fullgraph=True)(m)
1665*da0073e9SAndroid Build Coastguard Worker        res = opt_m(x)
1666*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ref, res))
1667*da0073e9SAndroid Build Coastguard Worker
1668*da0073e9SAndroid Build Coastguard Worker
1669*da0073e9SAndroid Build Coastguard Workerclass MockModule(torch.nn.Module):
1670*da0073e9SAndroid Build Coastguard Worker    def __init__(self) -> None:
1671*da0073e9SAndroid Build Coastguard Worker        super().__init__()
1672*da0073e9SAndroid Build Coastguard Worker        self.relu = torch.nn.ReLU()
1673*da0073e9SAndroid Build Coastguard Worker        self.linear = torch.nn.Linear(10, 10)
1674*da0073e9SAndroid Build Coastguard Worker        self.buf0 = torch.nn.Buffer(torch.randn(10, 10))
1675*da0073e9SAndroid Build Coastguard Worker
1676*da0073e9SAndroid Build Coastguard Worker    def forward(self, x):
1677*da0073e9SAndroid Build Coastguard Worker        return self.relu(self.linear(x) + self.buf0)
1678*da0073e9SAndroid Build Coastguard Worker
1679*da0073e9SAndroid Build Coastguard Worker
1680*da0073e9SAndroid Build Coastguard Workerclass OptimizedModuleTest(torch._dynamo.test_case.TestCase):
1681*da0073e9SAndroid Build Coastguard Worker    def test_nn_module(self):
1682*da0073e9SAndroid Build Coastguard Worker        mod = MockModule()
1683*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
1684*da0073e9SAndroid Build Coastguard Worker        opt_mod = torch._dynamo.optimize(cnt)(mod)
1685*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule)
1686*da0073e9SAndroid Build Coastguard Worker
1687*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 10)
1688*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x)))
1689*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
1690*da0073e9SAndroid Build Coastguard Worker
1691*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(guard_nn_modules=True)
1692*da0073e9SAndroid Build Coastguard Worker    def test_attr_precedence(self):
1693*da0073e9SAndroid Build Coastguard Worker        class Mod(torch.nn.Module):
1694*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1695*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1696*da0073e9SAndroid Build Coastguard Worker                self.a = 3
1697*da0073e9SAndroid Build Coastguard Worker
1698*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, c=4):
1699*da0073e9SAndroid Build Coastguard Worker                return x * c
1700*da0073e9SAndroid Build Coastguard Worker
1701*da0073e9SAndroid Build Coastguard Worker            def linear(self, x):
1702*da0073e9SAndroid Build Coastguard Worker                return x
1703*da0073e9SAndroid Build Coastguard Worker
1704*da0073e9SAndroid Build Coastguard Worker            def b(self, x):
1705*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError("Should not be called")
1706*da0073e9SAndroid Build Coastguard Worker
1707*da0073e9SAndroid Build Coastguard Worker        class MyMod(Mod):
1708*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1709*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1710*da0073e9SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(11, 11)
1711*da0073e9SAndroid Build Coastguard Worker                self.a = 2
1712*da0073e9SAndroid Build Coastguard Worker                self.b = 2
1713*da0073e9SAndroid Build Coastguard Worker                self.scale = 1
1714*da0073e9SAndroid Build Coastguard Worker
1715*da0073e9SAndroid Build Coastguard Worker            def scale(self, x):
1716*da0073e9SAndroid Build Coastguard Worker                # Should not be called because it is shadowed by the instance
1717*da0073e9SAndroid Build Coastguard Worker                # attribute
1718*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError("Should not be called")
1719*da0073e9SAndroid Build Coastguard Worker
1720*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, c=None):
1721*da0073e9SAndroid Build Coastguard Worker                return self.linear(x) * self.a * self.b * self.scale
1722*da0073e9SAndroid Build Coastguard Worker
1723*da0073e9SAndroid Build Coastguard Worker        mod = MyMod()
1724*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(3, 3)
1725*da0073e9SAndroid Build Coastguard Worker        ref = mod(x)
1726*da0073e9SAndroid Build Coastguard Worker
1727*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
1728*da0073e9SAndroid Build Coastguard Worker        opt_mod = torch.compile(mod, backend=cnts)
1729*da0073e9SAndroid Build Coastguard Worker        opt_mod(torch.ones(3, 3))
1730*da0073e9SAndroid Build Coastguard Worker        res = opt_mod(torch.ones(3, 3))
1731*da0073e9SAndroid Build Coastguard Worker
1732*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
1733*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
1734*da0073e9SAndroid Build Coastguard Worker
1735*da0073e9SAndroid Build Coastguard Worker    def test_to(self):
1736*da0073e9SAndroid Build Coastguard Worker        mod = MockModule()
1737*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
1738*da0073e9SAndroid Build Coastguard Worker        opt_mod = torch._dynamo.optimize(cnt)(mod)
1739*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 10)
1740*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x)))
1741*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
1742*da0073e9SAndroid Build Coastguard Worker
1743*da0073e9SAndroid Build Coastguard Worker        # Ensure that there is no recompilation
1744*da0073e9SAndroid Build Coastguard Worker        opt_mod(x)
1745*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
1746*da0073e9SAndroid Build Coastguard Worker
1747*da0073e9SAndroid Build Coastguard Worker        opt_mod = opt_mod.to(device="cpu").to(dtype=torch.float64)
1748*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule)
1749*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 10).to(dtype=torch.float64)
1750*da0073e9SAndroid Build Coastguard Worker        opt_mod(x)
1751*da0073e9SAndroid Build Coastguard Worker        # Ensure that there is a recompilation
1752*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 2)
1753*da0073e9SAndroid Build Coastguard Worker
1754*da0073e9SAndroid Build Coastguard Worker        # Ensure that there is no recompilation
1755*da0073e9SAndroid Build Coastguard Worker        opt_mod(x)
1756*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 2)
1757*da0073e9SAndroid Build Coastguard Worker
1758*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
1759*da0073e9SAndroid Build Coastguard Worker        opt_mod(x)
1760*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 3)
1761*da0073e9SAndroid Build Coastguard Worker
1762*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(guard_nn_modules=True)
1763*da0073e9SAndroid Build Coastguard Worker    def test_param_order(self):
1764*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
1765*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1766*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1767*da0073e9SAndroid Build Coastguard Worker                self.param1 = torch.nn.Parameter(torch.ones([1]))
1768*da0073e9SAndroid Build Coastguard Worker                self.param2 = torch.nn.Parameter(torch.ones([2]))
1769*da0073e9SAndroid Build Coastguard Worker
1770*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1771*da0073e9SAndroid Build Coastguard Worker                return x
1772*da0073e9SAndroid Build Coastguard Worker
1773*da0073e9SAndroid Build Coastguard Worker        mod = MyModule()
1774*da0073e9SAndroid Build Coastguard Worker        coeffs = [2, 3]
1775*da0073e9SAndroid Build Coastguard Worker
1776*da0073e9SAndroid Build Coastguard Worker        def fn(x):
1777*da0073e9SAndroid Build Coastguard Worker            for idx, p in enumerate(mod.parameters()):
1778*da0073e9SAndroid Build Coastguard Worker                x += p.sum() * coeffs[idx]
1779*da0073e9SAndroid Build Coastguard Worker
1780*da0073e9SAndroid Build Coastguard Worker            for idx, p in enumerate(mod.named_parameters()):
1781*da0073e9SAndroid Build Coastguard Worker                x += p[1].sum() * coeffs[idx]
1782*da0073e9SAndroid Build Coastguard Worker
1783*da0073e9SAndroid Build Coastguard Worker            return x
1784*da0073e9SAndroid Build Coastguard Worker
1785*da0073e9SAndroid Build Coastguard Worker        ref = fn(torch.ones(1))
1786*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
1787*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
1788*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(torch.ones(1))
1789*da0073e9SAndroid Build Coastguard Worker
1790*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
1791*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
1792*da0073e9SAndroid Build Coastguard Worker
1793*da0073e9SAndroid Build Coastguard Worker        mod._parameters["param1"] = mod._parameters.pop("param1")
1794*da0073e9SAndroid Build Coastguard Worker        ref = fn(torch.ones(1))
1795*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(torch.ones(1))
1796*da0073e9SAndroid Build Coastguard Worker
1797*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
1798*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
1799*da0073e9SAndroid Build Coastguard Worker
1800*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(guard_nn_modules=True)
1801*da0073e9SAndroid Build Coastguard Worker    def test_buffer_order(self):
1802*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
1803*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1804*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1805*da0073e9SAndroid Build Coastguard Worker                self.b1 = torch.nn.Buffer(torch.ones([1]))
1806*da0073e9SAndroid Build Coastguard Worker                self.b2 = torch.nn.Buffer(torch.ones([2]))
1807*da0073e9SAndroid Build Coastguard Worker
1808*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1809*da0073e9SAndroid Build Coastguard Worker                return x
1810*da0073e9SAndroid Build Coastguard Worker
1811*da0073e9SAndroid Build Coastguard Worker        mod = MyModule()
1812*da0073e9SAndroid Build Coastguard Worker        coeffs = [2, 3]
1813*da0073e9SAndroid Build Coastguard Worker
1814*da0073e9SAndroid Build Coastguard Worker        def fn(x):
1815*da0073e9SAndroid Build Coastguard Worker            for idx, p in enumerate(mod.buffers()):
1816*da0073e9SAndroid Build Coastguard Worker                x += p.sum() * coeffs[idx]
1817*da0073e9SAndroid Build Coastguard Worker
1818*da0073e9SAndroid Build Coastguard Worker            for idx, p in enumerate(mod.named_buffers()):
1819*da0073e9SAndroid Build Coastguard Worker                x += p[1].sum() * coeffs[idx]
1820*da0073e9SAndroid Build Coastguard Worker
1821*da0073e9SAndroid Build Coastguard Worker            return x
1822*da0073e9SAndroid Build Coastguard Worker
1823*da0073e9SAndroid Build Coastguard Worker        ref = fn(torch.ones(1))
1824*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
1825*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
1826*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(torch.ones(1))
1827*da0073e9SAndroid Build Coastguard Worker
1828*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
1829*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
1830*da0073e9SAndroid Build Coastguard Worker
1831*da0073e9SAndroid Build Coastguard Worker        mod._buffers["b1"] = mod._buffers.pop("b1")
1832*da0073e9SAndroid Build Coastguard Worker        ref = fn(torch.ones(1))
1833*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(torch.ones(1))
1834*da0073e9SAndroid Build Coastguard Worker
1835*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
1836*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
1837*da0073e9SAndroid Build Coastguard Worker
1838*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(guard_nn_modules=True)
1839*da0073e9SAndroid Build Coastguard Worker    def test_module_order(self):
1840*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
1841*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1842*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1843*da0073e9SAndroid Build Coastguard Worker                self.linear1 = torch.nn.Linear(3, 3)
1844*da0073e9SAndroid Build Coastguard Worker                self.linear2 = torch.nn.Linear(10, 10)
1845*da0073e9SAndroid Build Coastguard Worker
1846*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1847*da0073e9SAndroid Build Coastguard Worker                return x
1848*da0073e9SAndroid Build Coastguard Worker
1849*da0073e9SAndroid Build Coastguard Worker        mod = MyModule()
1850*da0073e9SAndroid Build Coastguard Worker        coeffs = [2, 3, 4]
1851*da0073e9SAndroid Build Coastguard Worker
1852*da0073e9SAndroid Build Coastguard Worker        coeffs_for_mod = {mod: 10, mod.linear1: 20, mod.linear2: 30}
1853*da0073e9SAndroid Build Coastguard Worker
1854*da0073e9SAndroid Build Coastguard Worker        # Check order of _modules
1855*da0073e9SAndroid Build Coastguard Worker        def fn(x):
1856*da0073e9SAndroid Build Coastguard Worker            for idx, p in enumerate(mod.modules()):
1857*da0073e9SAndroid Build Coastguard Worker                # Something silly to force depedency on the order
1858*da0073e9SAndroid Build Coastguard Worker                x += coeffs_for_mod[p] * coeffs[idx]
1859*da0073e9SAndroid Build Coastguard Worker            for idx, p in enumerate(mod.named_modules()):
1860*da0073e9SAndroid Build Coastguard Worker                x += coeffs_for_mod[p[1]] * coeffs[idx]
1861*da0073e9SAndroid Build Coastguard Worker            for idx, p in enumerate(mod.children()):
1862*da0073e9SAndroid Build Coastguard Worker                x += coeffs_for_mod[p] * coeffs[idx]
1863*da0073e9SAndroid Build Coastguard Worker            for idx, p in enumerate(mod.named_children()):
1864*da0073e9SAndroid Build Coastguard Worker                x += coeffs_for_mod[p[1]] * coeffs[idx]
1865*da0073e9SAndroid Build Coastguard Worker            return x
1866*da0073e9SAndroid Build Coastguard Worker
1867*da0073e9SAndroid Build Coastguard Worker        ref = fn(torch.ones(1))
1868*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
1869*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
1870*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(torch.ones(1))
1871*da0073e9SAndroid Build Coastguard Worker
1872*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
1873*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
1874*da0073e9SAndroid Build Coastguard Worker
1875*da0073e9SAndroid Build Coastguard Worker        mod._modules["linear1"] = mod._modules.pop("linear1")
1876*da0073e9SAndroid Build Coastguard Worker        ref = fn(torch.ones(1))
1877*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(torch.ones(1))
1878*da0073e9SAndroid Build Coastguard Worker
1879*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
1880*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
1881*da0073e9SAndroid Build Coastguard Worker
1882*da0073e9SAndroid Build Coastguard Worker    def test_attr(self):
1883*da0073e9SAndroid Build Coastguard Worker        class MockModule(torch.nn.Module):
1884*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1885*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1886*da0073e9SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(10, 10)
1887*da0073e9SAndroid Build Coastguard Worker                self.buf0 = torch.nn.Buffer(torch.randn(10, 10))
1888*da0073e9SAndroid Build Coastguard Worker
1889*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1890*da0073e9SAndroid Build Coastguard Worker                return self.r(torch.sin(x)) + self.buf0
1891*da0073e9SAndroid Build Coastguard Worker
1892*da0073e9SAndroid Build Coastguard Worker        mod = MockModule()
1893*da0073e9SAndroid Build Coastguard Worker        opt_mod = torch._dynamo.optimize("eager")(mod)
1894*da0073e9SAndroid Build Coastguard Worker
1895*da0073e9SAndroid Build Coastguard Worker        # Check parameters and buffers
1896*da0073e9SAndroid Build Coastguard Worker        for p1, p2 in zip(mod.parameters(), opt_mod.parameters()):
1897*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(id(p1) == id(p2))
1898*da0073e9SAndroid Build Coastguard Worker        for b1, b2 in zip(mod.buffers(), opt_mod.buffers()):
1899*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(id(b1) == id(b2))
1900*da0073e9SAndroid Build Coastguard Worker
1901*da0073e9SAndroid Build Coastguard Worker        def get_parameter_dtype(mod: torch.nn.Module):
1902*da0073e9SAndroid Build Coastguard Worker            parameters_and_buffers = itertools.chain(mod.parameters(), mod.buffers())
1903*da0073e9SAndroid Build Coastguard Worker            return next(parameters_and_buffers).dtype
1904*da0073e9SAndroid Build Coastguard Worker
1905*da0073e9SAndroid Build Coastguard Worker        opt_mod = torch._dynamo.optimize("eager")(get_parameter_dtype)
1906*da0073e9SAndroid Build Coastguard Worker        out_dtype = opt_mod(mod)
1907*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_dtype, torch.float32)
1908*da0073e9SAndroid Build Coastguard Worker
1909*da0073e9SAndroid Build Coastguard Worker    def test_dir(self):
1910*da0073e9SAndroid Build Coastguard Worker        class MockModule(torch.nn.Module):
1911*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1912*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1913*da0073e9SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(10, 10)
1914*da0073e9SAndroid Build Coastguard Worker                self.buf0 = torch.nn.Buffer(torch.nn.Buffer(torch.randn(10, 10)))
1915*da0073e9SAndroid Build Coastguard Worker                self.register_parameter(
1916*da0073e9SAndroid Build Coastguard Worker                    name="param0", param=torch.nn.Parameter(torch.randn(10, 10))
1917*da0073e9SAndroid Build Coastguard Worker                )
1918*da0073e9SAndroid Build Coastguard Worker
1919*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1920*da0073e9SAndroid Build Coastguard Worker                return self.r(torch.sin(x)) + self.buf0
1921*da0073e9SAndroid Build Coastguard Worker
1922*da0073e9SAndroid Build Coastguard Worker        mod = MockModule()
1923*da0073e9SAndroid Build Coastguard Worker        mod_keys = dir(mod)
1924*da0073e9SAndroid Build Coastguard Worker        opt_mod = torch._dynamo.optimize("eager")(mod)
1925*da0073e9SAndroid Build Coastguard Worker        opt_mod_keys = dir(opt_mod)
1926*da0073e9SAndroid Build Coastguard Worker
1927*da0073e9SAndroid Build Coastguard Worker        # Check user-defined attributes, parameters and buffers
1928*da0073e9SAndroid Build Coastguard Worker        self.assertIn("linear", opt_mod_keys)
1929*da0073e9SAndroid Build Coastguard Worker        self.assertIn("buf0", opt_mod_keys)
1930*da0073e9SAndroid Build Coastguard Worker        self.assertIn("param0", opt_mod_keys)
1931*da0073e9SAndroid Build Coastguard Worker
1932*da0073e9SAndroid Build Coastguard Worker        # Check all attributes, parameters and buffers
1933*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(set(mod_keys).difference(opt_mod_keys)) == 0)
1934*da0073e9SAndroid Build Coastguard Worker
1935*da0073e9SAndroid Build Coastguard Worker    def test_no_recompile_on_nn_guarded_modules(self):
1936*da0073e9SAndroid Build Coastguard Worker        size = (10, 10)
1937*da0073e9SAndroid Build Coastguard Worker        cache_size_limit = 1
1938*da0073e9SAndroid Build Coastguard Worker        num_submodules = 4
1939*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")
1940*da0073e9SAndroid Build Coastguard Worker
1941*da0073e9SAndroid Build Coastguard Worker        class SubModule(torch.nn.Module):
1942*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1943*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1944*da0073e9SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(*size)
1945*da0073e9SAndroid Build Coastguard Worker
1946*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1947*da0073e9SAndroid Build Coastguard Worker                a = torch.sin(torch.cos(x))
1948*da0073e9SAndroid Build Coastguard Worker                return self.linear(a)
1949*da0073e9SAndroid Build Coastguard Worker
1950*da0073e9SAndroid Build Coastguard Worker        class MockModule(torch.nn.Module):
1951*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1952*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1953*da0073e9SAndroid Build Coastguard Worker                self.mods = [SubModule() for _ in range(num_submodules)]
1954*da0073e9SAndroid Build Coastguard Worker                self.mods = [torch.compile(mod, backend=cnts) for mod in self.mods]
1955*da0073e9SAndroid Build Coastguard Worker
1956*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1957*da0073e9SAndroid Build Coastguard Worker                for mod in self.mods:
1958*da0073e9SAndroid Build Coastguard Worker                    x = mod(x)
1959*da0073e9SAndroid Build Coastguard Worker                return x
1960*da0073e9SAndroid Build Coastguard Worker
1961*da0073e9SAndroid Build Coastguard Worker        mod = MockModule()
1962*da0073e9SAndroid Build Coastguard Worker        # Each submod is compiled separately and has a different nn module
1963*da0073e9SAndroid Build Coastguard Worker        # guard. Ensure that recompilation logic is handle correctly.
1964*da0073e9SAndroid Build Coastguard Worker        with unittest.mock.patch(
1965*da0073e9SAndroid Build Coastguard Worker            "torch._dynamo.config.error_on_recompile", True
1966*da0073e9SAndroid Build Coastguard Worker        ), unittest.mock.patch(
1967*da0073e9SAndroid Build Coastguard Worker            "torch._dynamo.config.cache_size_limit",
1968*da0073e9SAndroid Build Coastguard Worker            cache_size_limit,
1969*da0073e9SAndroid Build Coastguard Worker        ):
1970*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(*size, requires_grad=True)
1971*da0073e9SAndroid Build Coastguard Worker            mod(x)
1972*da0073e9SAndroid Build Coastguard Worker            if torch._dynamo.config.inline_inbuilt_nn_modules:
1973*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(cnts.frame_count, 1)
1974*da0073e9SAndroid Build Coastguard Worker            else:
1975*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(cnts.frame_count, num_submodules)
1976*da0073e9SAndroid Build Coastguard Worker
1977*da0073e9SAndroid Build Coastguard Worker    @patch.object(torch._dynamo.config, "accumulated_cache_size_limit", 2)
1978*da0073e9SAndroid Build Coastguard Worker    @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", False)
1979*da0073e9SAndroid Build Coastguard Worker    def test_recompile_limit_on_freed_module(self):
1980*da0073e9SAndroid Build Coastguard Worker        class Mod(torch.nn.Module):
1981*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1982*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1983*da0073e9SAndroid Build Coastguard Worker                self.lin = torch.nn.Linear(5, 5)
1984*da0073e9SAndroid Build Coastguard Worker
1985*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1986*da0073e9SAndroid Build Coastguard Worker                return self.lin(x)
1987*da0073e9SAndroid Build Coastguard Worker
1988*da0073e9SAndroid Build Coastguard Worker        def fn(x, mod):
1989*da0073e9SAndroid Build Coastguard Worker            return mod(x)
1990*da0073e9SAndroid Build Coastguard Worker
1991*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")
1992*da0073e9SAndroid Build Coastguard Worker        opt_mod = torch.compile(fn, backend=cnts)
1993*da0073e9SAndroid Build Coastguard Worker        for i in range(8):
1994*da0073e9SAndroid Build Coastguard Worker            mod = Mod()
1995*da0073e9SAndroid Build Coastguard Worker            opt_mod(torch.randn(5, 5), mod)
1996*da0073e9SAndroid Build Coastguard Worker
1997*da0073e9SAndroid Build Coastguard Worker        # fn compiles twice
1998*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
1999*da0073e9SAndroid Build Coastguard Worker
2000*da0073e9SAndroid Build Coastguard Worker    @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", True)
2001*da0073e9SAndroid Build Coastguard Worker    def test_inline_inbuilt_nn_modules(self):
2002*da0073e9SAndroid Build Coastguard Worker        size = (10, 10)
2003*da0073e9SAndroid Build Coastguard Worker        cache_size_limit = 1
2004*da0073e9SAndroid Build Coastguard Worker        num_submodules = 4
2005*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")
2006*da0073e9SAndroid Build Coastguard Worker
2007*da0073e9SAndroid Build Coastguard Worker        class SubModule(torch.nn.Module):
2008*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2009*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2010*da0073e9SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(*size)
2011*da0073e9SAndroid Build Coastguard Worker
2012*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2013*da0073e9SAndroid Build Coastguard Worker                a = torch.sin(torch.cos(x))
2014*da0073e9SAndroid Build Coastguard Worker                return self.linear(a)
2015*da0073e9SAndroid Build Coastguard Worker
2016*da0073e9SAndroid Build Coastguard Worker        class MockModule(torch.nn.Module):
2017*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2018*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2019*da0073e9SAndroid Build Coastguard Worker                self.mods = [SubModule() for _ in range(num_submodules)]
2020*da0073e9SAndroid Build Coastguard Worker                self.mods = [torch.compile(mod, backend=cnts) for mod in self.mods]
2021*da0073e9SAndroid Build Coastguard Worker
2022*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2023*da0073e9SAndroid Build Coastguard Worker                for mod in self.mods:
2024*da0073e9SAndroid Build Coastguard Worker                    x = mod(x)
2025*da0073e9SAndroid Build Coastguard Worker                return x
2026*da0073e9SAndroid Build Coastguard Worker
2027*da0073e9SAndroid Build Coastguard Worker        mod = MockModule()
2028*da0073e9SAndroid Build Coastguard Worker        # Each submod is compiled separately and has a different nn module
2029*da0073e9SAndroid Build Coastguard Worker        # guard. Ensure that recompilation logic is handle correctly.
2030*da0073e9SAndroid Build Coastguard Worker        with unittest.mock.patch(
2031*da0073e9SAndroid Build Coastguard Worker            "torch._dynamo.config.error_on_recompile", True
2032*da0073e9SAndroid Build Coastguard Worker        ), unittest.mock.patch(
2033*da0073e9SAndroid Build Coastguard Worker            "torch._dynamo.config.cache_size_limit",
2034*da0073e9SAndroid Build Coastguard Worker            cache_size_limit,
2035*da0073e9SAndroid Build Coastguard Worker        ):
2036*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(*size, requires_grad=True)
2037*da0073e9SAndroid Build Coastguard Worker            mod(x)
2038*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnts.frame_count, 1)
2039*da0073e9SAndroid Build Coastguard Worker
2040*da0073e9SAndroid Build Coastguard Worker    def test_cache_size_limit_on_guarded_nn_modules(self):
2041*da0073e9SAndroid Build Coastguard Worker        cache_size_limit = 2
2042*da0073e9SAndroid Build Coastguard Worker        num_submodules = 4
2043*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")
2044*da0073e9SAndroid Build Coastguard Worker
2045*da0073e9SAndroid Build Coastguard Worker        class SubModule(torch.nn.Module):
2046*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2047*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2048*da0073e9SAndroid Build Coastguard Worker                self.relu = torch.nn.ReLU()
2049*da0073e9SAndroid Build Coastguard Worker
2050*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2051*da0073e9SAndroid Build Coastguard Worker                a = torch.sin(torch.cos(x))
2052*da0073e9SAndroid Build Coastguard Worker                return self.relu(a)
2053*da0073e9SAndroid Build Coastguard Worker
2054*da0073e9SAndroid Build Coastguard Worker        class MockModule(torch.nn.Module):
2055*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2056*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2057*da0073e9SAndroid Build Coastguard Worker                self.mods = [SubModule() for _ in range(num_submodules)]
2058*da0073e9SAndroid Build Coastguard Worker                self.mods = [torch.compile(mod, backend=cnts) for mod in self.mods]
2059*da0073e9SAndroid Build Coastguard Worker
2060*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2061*da0073e9SAndroid Build Coastguard Worker                for mod in self.mods:
2062*da0073e9SAndroid Build Coastguard Worker                    x = mod(x)
2063*da0073e9SAndroid Build Coastguard Worker                return x
2064*da0073e9SAndroid Build Coastguard Worker
2065*da0073e9SAndroid Build Coastguard Worker        mod = MockModule()
2066*da0073e9SAndroid Build Coastguard Worker        # For the third iteration, we would reach the cache size limit, and
2067*da0073e9SAndroid Build Coastguard Worker        # therefore the total number of expected frame count is 2 *
2068*da0073e9SAndroid Build Coastguard Worker        # num_submodules.
2069*da0073e9SAndroid Build Coastguard Worker        with unittest.mock.patch(
2070*da0073e9SAndroid Build Coastguard Worker            "torch._dynamo.config.cache_size_limit",
2071*da0073e9SAndroid Build Coastguard Worker            cache_size_limit,
2072*da0073e9SAndroid Build Coastguard Worker        ):
2073*da0073e9SAndroid Build Coastguard Worker            for size in [
2074*da0073e9SAndroid Build Coastguard Worker                (4,),
2075*da0073e9SAndroid Build Coastguard Worker                (4, 4),
2076*da0073e9SAndroid Build Coastguard Worker                (4, 4, 4),
2077*da0073e9SAndroid Build Coastguard Worker            ]:
2078*da0073e9SAndroid Build Coastguard Worker                x = torch.randn(size)
2079*da0073e9SAndroid Build Coastguard Worker                mod(x)
2080*da0073e9SAndroid Build Coastguard Worker        if torch._dynamo.config.inline_inbuilt_nn_modules:
2081*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnts.frame_count, 2)
2082*da0073e9SAndroid Build Coastguard Worker        else:
2083*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnts.frame_count, 2 * num_submodules)
2084*da0073e9SAndroid Build Coastguard Worker
2085*da0073e9SAndroid Build Coastguard Worker    def test_recursion(self):
2086*da0073e9SAndroid Build Coastguard Worker        mod = MockModule()
2087*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
2088*da0073e9SAndroid Build Coastguard Worker        opt_mod = torch._dynamo.optimize(cnt)(mod)
2089*da0073e9SAndroid Build Coastguard Worker
2090*da0073e9SAndroid Build Coastguard Worker        for _ in range(5):
2091*da0073e9SAndroid Build Coastguard Worker            opt_mod = torch._dynamo.optimize(cnt)(opt_mod)
2092*da0073e9SAndroid Build Coastguard Worker        opt_mod(torch.randn(10, 10))
2093*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
2094*da0073e9SAndroid Build Coastguard Worker
2095*da0073e9SAndroid Build Coastguard Worker    def test_composition(self):
2096*da0073e9SAndroid Build Coastguard Worker        class InnerModule(torch.nn.Module):
2097*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2098*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2099*da0073e9SAndroid Build Coastguard Worker                self.relu = torch.nn.ReLU()
2100*da0073e9SAndroid Build Coastguard Worker
2101*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2102*da0073e9SAndroid Build Coastguard Worker                return self.relu(torch.sin(x))
2103*da0073e9SAndroid Build Coastguard Worker
2104*da0073e9SAndroid Build Coastguard Worker        opt_inner_mod = InnerModule()
2105*da0073e9SAndroid Build Coastguard Worker
2106*da0073e9SAndroid Build Coastguard Worker        class OuterModule(torch.nn.Module):
2107*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2108*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2109*da0073e9SAndroid Build Coastguard Worker                self.mod = opt_inner_mod
2110*da0073e9SAndroid Build Coastguard Worker
2111*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2112*da0073e9SAndroid Build Coastguard Worker                return self.mod(torch.cos(x))
2113*da0073e9SAndroid Build Coastguard Worker
2114*da0073e9SAndroid Build Coastguard Worker        outer_mod = OuterModule()
2115*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
2116*da0073e9SAndroid Build Coastguard Worker        opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod)
2117*da0073e9SAndroid Build Coastguard Worker
2118*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
2119*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule)
2120*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x)))
2121*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
2122*da0073e9SAndroid Build Coastguard Worker
def test_composition_with_opt_mod(self):
    # Nest an already-optimized inner module inside an outer module that is
    # optimized with the same compile counter, then compare the optimized
    # outer module against the plain outer module.
    class InnerModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(torch.sin(x))

    counter = torch._dynamo.testing.CompileCounter()
    opt_inner_mod = torch._dynamo.optimize(counter)(InnerModule())

    class OuterModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mod = opt_inner_mod

        def forward(self, x):
            return self.mod(torch.cos(x))

    plain_outer = OuterModule()
    compiled_outer = torch._dynamo.optimize(counter)(plain_outer)

    sample = torch.randn(4)
    self.assertIsInstance(compiled_outer, torch._dynamo.OptimizedModule)

    # Call the plain outer module first, then the optimized one, exactly as
    # a single `same(plain(x), compiled(x))` expression would evaluate them.
    reference = plain_outer(sample)
    result = compiled_outer(sample)
    self.assertTrue(torch._dynamo.testing.same(reference, result))

    # There will be a graph break for the inner mod being OptimizedModule
    self.assertEqual(counter.frame_count, 2)
2152*da0073e9SAndroid Build Coastguard Worker
def test_module_patch(self):
    # Rebind `forward` on a ModulePatch1 instance to ModulePatch2's
    # implementation, then call the instance through dynamo with
    # nopython=True and compare the result against zeros.
    mod = ModulePatch1()
    mod.forward = types.MethodType(ModulePatch2.forward, mod)

    def fn(x):
        return mod(x)

    compiled_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
    result = compiled_fn(torch.ones(10))
    expected = torch.zeros(1)
    self.assertTrue(torch.allclose(result, expected))
2166*da0073e9SAndroid Build Coastguard Worker
@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
def test_hooks_outer(self):
    # The forward hook lives on the very module that is handed to dynamo.
    # Check the output with the hook installed and again after removing it.
    class TestModule(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return 2 * x + 1

    m = TestModule()

    def forward_hook(
        module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
    ) -> torch.Tensor:
        return 2 * output + 1

    hook_handle = m.register_forward_hook(forward_hook)
    inp = torch.tensor(1.0, requires_grad=True)

    failure_reason = None

    def guard_fail_fn(failure):
        # Remember the first entry of whatever dynamo reports on guard failure.
        nonlocal failure_reason
        failure_reason = failure[0]

    optimize = torch._dynamo.optimize(guard_fail_fn=guard_fail_fn, backend="eager")
    compiled_m = optimize(m)

    # With the hook: forward gives 2*1+1 = 3, the hook turns it into 2*3+1 = 7.
    compiled_out = compiled_m(inp)
    eager_out = m(inp)
    self.assertEqual(compiled_out, eager_out)
    self.assertEqual(compiled_m(inp).item(), 7)
    self.assertIsNone(failure_reason)

    # what if we remove our hook? we should recompile?
    hook_handle.remove()
    compiled_out = compiled_m(inp)
    eager_out = m(inp)
    self.assertEqual(compiled_out, eager_out)
    self.assertEqual(compiled_m(inp).item(), 3)
    # Left disabled on purpose: self.assertTrue(failure_reason == "hook")

    """
        Summary:
          - removing a hook doesn't fail a guard, because we weren't compiling the hook
            (at least into the same graph) as forward in the first place! We do correctly
            omit calling the removed hook, but since this hook is a post forward hook,
            the 'RETURN' from forward is breaking the graph.

            Why is 'forward' the entrypoint to an InstructionTranslator, after I changed
            the eval_frame entrypoint to Module.__call__?
        """
2213*da0073e9SAndroid Build Coastguard Worker
@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
def test_hooks_inner(self):
    # The hooked module is called from inside the compiled function.
    # Removing or swapping the hook afterwards is expected to fail a guard
    # and trigger a recompile.
    class TestModule(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return 2 * x + 1

    # NOTE: `m` is captured by `outer_func` and is rebound further down; the
    # guard-failure regex at the end also refers to it by this name.
    m = TestModule()

    def forward_hook(
        module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
    ) -> torch.Tensor:
        return 2 * output + 1

    handle = m.register_forward_hook(forward_hook)

    def outer_func(tensor):
        x = tensor * 2 + 1
        y = m(x)
        return y

    inp = torch.tensor(1.0, requires_grad=True)

    failure_reason = None

    def guard_fail_fn(failure):
        # Remember the first entry of whatever dynamo reports on guard failure.
        nonlocal failure_reason
        failure_reason = failure[0]

    cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
    optimize = torch._dynamo.optimize(guard_fail_fn=guard_fail_fn, backend=cc)
    compiled_func = optimize(outer_func)

    compiled_out = compiled_func(inp)
    eager_out = outer_func(inp)
    self.assertEqual(compiled_out, eager_out)
    self.assertEqual(compiled_func(inp).item(), 15)

    # We are compiling 1 big graph for all 3 functions including the hook.
    self.assertEqual(cc.frame_count, 1)
    self.assertEqual(cc.op_count, 6)

    # If we remove the hook, we should recompile
    handle.remove()
    compiled_out = compiled_func(inp)
    eager_out = outer_func(inp)
    self.assertEqual(compiled_out, eager_out)
    self.assertEqual(compiled_func(inp).item(), 7)
    self.assertIn("forward_hooks", failure_reason)
    # One additional frame, and four additional ops, on top of the first compile.
    self.assertEqual(cc.frame_count, 2)
    self.assertEqual(cc.op_count, 10)

    # what if instead of removing, we alter our hook?
    torch._dynamo.reset()
    m = TestModule()
    handle = m.register_forward_hook(forward_hook)
    failure_reason = None
    compiled_out = compiled_func(inp)
    eager_out = outer_func(inp)
    self.assertEqual(compiled_out, eager_out)
    self.assertEqual(compiled_func(inp).item(), 15)

    def new_forward_hook(
        module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
    ) -> torch.Tensor:
        return 2 * output + 2

    # Swap the registered hook in place, keeping the same handle id.
    m._forward_hooks[handle.id] = new_forward_hook
    compiled_out = compiled_func(inp)
    eager_out = outer_func(inp)
    self.assertEqual(compiled_out, eager_out)
    self.assertEqual(compiled_func(inp).item(), 16)
    self.assertRegex(failure_reason, r"___check_obj_id\(L\['m'\]._forward_hooks")
2280*da0073e9SAndroid Build Coastguard Worker
@patch.object(torch._dynamo.config, "guard_nn_modules", False)
@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", True)
@patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", False)
def test_hooks_skip_guards(self):
    # With hook guards skipped, removing a hook after compilation is expected
    # to go unnoticed: the compiled function keeps returning the hooked value.
    class TestModule(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return 2 * x + 1

    m = TestModule()

    def forward_hook(
        module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
    ) -> torch.Tensor:
        return 2 * output + 1

    handle = m.register_forward_hook(forward_hook)

    def outer_func(tensor):
        x = tensor * 2 + 1
        y = m(x)
        return y

    inp = torch.tensor(1.0, requires_grad=True)

    failure_reason = None

    def guard_fail_fn(failure):
        # Remember the first entry of whatever dynamo reports on guard failure.
        nonlocal failure_reason
        failure_reason = failure[0]

    cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
    optimize = torch._dynamo.optimize(guard_fail_fn=guard_fail_fn, backend=cc)
    compiled_func = optimize(outer_func)

    # Rebind the captured module to a fresh instance (with its own hook)
    # before the first call of the compiled function.
    m = TestModule()
    handle = m.register_forward_hook(forward_hook)
    failure_reason = None

    compiled_out = compiled_func(inp)
    eager_out = outer_func(inp)
    self.assertEqual(compiled_out, eager_out)
    self.assertEqual(compiled_func(inp).item(), 15)
    self.assertEqual(cc.frame_count, 1)
    self.assertEqual(cc.op_count, 6)

    # if we remove the hook, dynamo shouldn't notice
    handle.remove()
    compiled_out = compiled_func(inp)
    eager_out = outer_func(inp)
    self.assertNotEqual(compiled_out, eager_out)
    self.assertEqual(compiled_func(inp).item(), 15)
    self.assertEqual(cc.frame_count, 1)
2330*da0073e9SAndroid Build Coastguard Worker
def _forward_hook_test_helper(self, model):
    # Register an input-recording forward hook on every module of `model`,
    # run the compiled model and then the eager model twice each, and check
    # that both runs recorded one entry per hooked module name.
    compiled_activations = {}
    eager_activations = {}
    # The hook writes into whichever dict `activations` currently refers to.
    activations = None

    def save_activations(name, mod, inp, out):
        activations[name] = inp

    forward_handles = {
        name: module.register_forward_hook(partial(save_activations, name))
        for name, module in model.named_modules()
    }

    compiled_model = torch.compile(model, backend="aot_eager")

    def run_two_steps(fn, recorded):
        # second iteration is key, hooks would have fired during aot trace
        # on first iter
        for _ in range(2):
            recorded.clear()
            batch = torch.randn((20, 10))
            fn(batch).sum().backward()

    activations = compiled_activations
    run_two_steps(compiled_model, compiled_activations)

    activations = eager_activations
    run_two_steps(model, eager_activations)

    print(f"Recorded Layers: {compiled_activations.keys()}\n\n")
    print(f"Expected Layers: {eager_activations.keys()}")

    self.assertTrue(compiled_activations.keys() == eager_activations.keys())
    self.assertTrue(activations.keys() == forward_handles.keys())
2372*da0073e9SAndroid Build Coastguard Worker
def test_hooks_allowed_modules(self):
    # this test shouldn't care whether hook guards are enabled or not
    class ToyModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.net = torch.nn.Sequential(
                torch.nn.Linear(10, 10000),
                torch.nn.ReLU(),
                torch.nn.Linear(10000, 5),
                torch.nn.ReLU(),
            )

        def forward(self, x):
            return self.net(x)

    # Forward hooks on built-in torch.nn layers go through the shared helper.
    self._forward_hook_test_helper(ToyModel())
2388*da0073e9SAndroid Build Coastguard Worker
def test_hooks_allowed_modules_compiles(self):
    # Hook every module of a small Sequential-based model, compile it with
    # nopython=True, and check both the number of hook invocations in one
    # forward pass and the number of compiled frames.
    class ToyModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.net = torch.nn.Sequential(
                torch.nn.Linear(10, 10000),
                torch.nn.ReLU(),
                torch.nn.Linear(10000, 5),
                torch.nn.ReLU(),
            )

        def forward(self, x):
            return self.net(x)

    model = ToyModel()
    activations = []

    def save_activations(mod, inp, out):
        activations.append(inp)

    for module in model.modules():
        module.register_forward_hook(save_activations)

    cnt = torch._dynamo.testing.CompileCounter()
    model = torch._dynamo.optimize(cnt, nopython=True)(model)
    for _ in range(2):
        # second iteration is key, hooks would have fired during aot trace
        # on first iter
        activations.clear()
        batch = torch.randn((20, 10))
        model(batch).sum().backward()

    # The list is cleared at the top of each iteration, so this counts the
    # hook calls of the last forward pass only.
    self.assertEqual(len(activations), 6)
    self.assertEqual(cnt.frame_count, 1)
2422*da0073e9SAndroid Build Coastguard Worker
def test_hooks_allowed_modules_compiles_self_contained(self):
    # Output-modifying forward hooks on every module must give the same
    # predictions, loss and gradients under dynamo (nopython=True) as in
    # eager mode, and the number of compiled frames must match expectations.
    class ToyModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.net = torch.nn.Sequential(
                *[torch.nn.Linear(10, 10000), torch.nn.ReLU()]
                + [torch.nn.Linear(10000, 5), torch.nn.ReLU()]
            )

        def forward(self, x):
            return self.net(x) * self.net(x)

    model = ToyModel()

    def output_modifying_hook(mod, inp, out):
        return 2 * out + 1

    for module in model.modules():
        module.register_forward_hook(output_modifying_hook)

    cnt = torch._dynamo.testing.CompileCounter()

    x = torch.randn((20, 10))
    pred_eager = model(x)
    loss_eager = pred_eager.sum()
    loss_eager.backward()

    # Tensor.backward() returns None, so comparing its return values (as this
    # test used to) checks nothing. Snapshot the eager gradients instead.
    params = list(model.parameters())
    eager_grads = [p.grad.clone() for p in params]

    model = torch._dynamo.optimize(cnt, nopython=True)(model)
    pred = model(x)

    loss = pred.sum()
    loss.backward()

    self.assertEqual(pred_eager, pred)
    self.assertEqual(loss_eager, loss)

    # The second backward accumulated into the same .grad tensors, so the
    # compiled run's contribution is the difference from the eager snapshot.
    compiled_grads = [p.grad - g for p, g in zip(params, eager_grads)]
    self.assertEqual(eager_grads, compiled_grads)
    self.assertEqual(cnt.frame_count, 2)

    # Ndim change, recompile
    pred = model(torch.randn([10, 10, 10]))
    self.assertEqual(cnt.frame_count, 4)

    # Stable
    pred = model(torch.randn([10, 10, 10]))
    self.assertEqual(cnt.frame_count, 4)
2467*da0073e9SAndroid Build Coastguard Worker
def test_dunder_call_explicitly(self):
    # hooks should be triggered if explicit calling `__call__`
    class ToyModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(10, 10000)

        def forward(self, x):
            return self.linear.__call__(x)

    # The shared helper hooks every module and compares compiled vs. eager.
    self._forward_hook_test_helper(ToyModel())
2480*da0073e9SAndroid Build Coastguard Worker
def test_backward_hooks(self):
    # this test shouldn't care whether hook guards are enabled or not

    class CustomLinear(torch.nn.Module):
        # not an 'allowed module', so should not graph-break
        def __init__(self, a, b):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(a, b))

        def forward(self, x):
            return torch.mm(x, self.weight)

    class ToyModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.net = torch.nn.Sequential(
                CustomLinear(10, 10),
                CustomLinear(10, 10000),
                CustomLinear(10000, 5),
            )

        def forward(self, x):
            return self.net(x)

    model = ToyModel()

    # Filled by the hooks below, keyed by module name. The values are stored
    # as unevaluated generator expressions; only the keys are checked.
    grad_sizes = {}
    pre_grad_sizes = {}

    def backward_hook(name, mod, grad_inp, grad_out):
        grad_sizes[name] = (
            (gi.shape for gi in grad_inp),
            (go.shape for go in grad_out),
        )

    def backward_pre_hook(name, mod, grad_out):
        pre_grad_sizes[name] = (go.shape for go in grad_out)

    backward_hook_handles = {}
    pre_backward_hook_handles = {}
    for name, module in model.named_modules():
        full_hook = partial(backward_hook, name)
        backward_hook_handles[name] = module.register_full_backward_hook(full_hook)

        pre_hook = partial(backward_pre_hook, name)
        pre_backward_hook_handles[name] = module.register_full_backward_pre_hook(
            pre_hook
        )

    model = torch.compile(model, backend="aot_eager")

    for _ in range(2):
        # second iteration is key, hooks would have fired during aot trace
        # on first iter
        batch = torch.randn((20, 10))
        model(batch).sum().backward()

    # Every module that had a hook registered must have had it fire.
    self.assertTrue(grad_sizes.keys() == backward_hook_handles.keys())
    self.assertTrue(pre_grad_sizes.keys() == pre_backward_hook_handles.keys())
2545*da0073e9SAndroid Build Coastguard Worker
2546*da0073e9SAndroid Build Coastguard Worker    def test_udo_instance_method_as_hook(self):
2547*da0073e9SAndroid Build Coastguard Worker        class CustomClass:
2548*da0073e9SAndroid Build Coastguard Worker            def __init__(self, module):
2549*da0073e9SAndroid Build Coastguard Worker                self.module = module
2550*da0073e9SAndroid Build Coastguard Worker                self.handle = self.module.register_forward_pre_hook(
2551*da0073e9SAndroid Build Coastguard Worker                    self.func1, prepend=True, with_kwargs=True
2552*da0073e9SAndroid Build Coastguard Worker                )
2553*da0073e9SAndroid Build Coastguard Worker
2554*da0073e9SAndroid Build Coastguard Worker            def func1(self, module, args, kwargs):
2555*da0073e9SAndroid Build Coastguard Worker                return (args[0] + 1,), kwargs
2556*da0073e9SAndroid Build Coastguard Worker
2557*da0073e9SAndroid Build Coastguard Worker            def __call__(self, x):
2558*da0073e9SAndroid Build Coastguard Worker                return self.module(x)
2559*da0073e9SAndroid Build Coastguard Worker
2560*da0073e9SAndroid Build Coastguard Worker        class ToyModel(torch.nn.Module):
2561*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2562*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2563*da0073e9SAndroid Build Coastguard Worker
2564*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2565*da0073e9SAndroid Build Coastguard Worker                return x * x
2566*da0073e9SAndroid Build Coastguard Worker
2567*da0073e9SAndroid Build Coastguard Worker        model = ToyModel()
2568*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros((3, 4))
2569*da0073e9SAndroid Build Coastguard Worker        obj = CustomClass(model)
2570*da0073e9SAndroid Build Coastguard Worker        out = torch.compile(obj, fullgraph=True)(x)
2571*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, (x + 1) * (x + 1))
2572*da0073e9SAndroid Build Coastguard Worker
2573*da0073e9SAndroid Build Coastguard Worker    def test_module_dict_iter_name(self):
2574*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
2575*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2576*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2577*da0073e9SAndroid Build Coastguard Worker                self.activations = torch.nn.ModuleDict(
2578*da0073e9SAndroid Build Coastguard Worker                    [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
2579*da0073e9SAndroid Build Coastguard Worker                )
2580*da0073e9SAndroid Build Coastguard Worker
2581*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2582*da0073e9SAndroid Build Coastguard Worker                for activation_name in self.activations:
2583*da0073e9SAndroid Build Coastguard Worker                    x = self.activations[activation_name](x)
2584*da0073e9SAndroid Build Coastguard Worker                return x
2585*da0073e9SAndroid Build Coastguard Worker
2586*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
2587*da0073e9SAndroid Build Coastguard Worker        # Eager
2588*da0073e9SAndroid Build Coastguard Worker        eager_res = MyModule()(torch.ones(10, 10))
2589*da0073e9SAndroid Build Coastguard Worker
2590*da0073e9SAndroid Build Coastguard Worker        # Compile
2591*da0073e9SAndroid Build Coastguard Worker        optim_res = torch._dynamo.optimize(cnt)(MyModule())(torch.ones(10, 10))
2592*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager_res, optim_res)
2593*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
2594*da0073e9SAndroid Build Coastguard Worker
2595*da0073e9SAndroid Build Coastguard Worker    def test_module_dict_iter_keys(self):
2596*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
2597*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2598*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2599*da0073e9SAndroid Build Coastguard Worker                self.activations = torch.nn.ModuleDict(
2600*da0073e9SAndroid Build Coastguard Worker                    [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
2601*da0073e9SAndroid Build Coastguard Worker                )
2602*da0073e9SAndroid Build Coastguard Worker
2603*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2604*da0073e9SAndroid Build Coastguard Worker                for activation_name in self.activations.keys():
2605*da0073e9SAndroid Build Coastguard Worker                    x = self.activations[activation_name](x)
2606*da0073e9SAndroid Build Coastguard Worker                return x
2607*da0073e9SAndroid Build Coastguard Worker
2608*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
2609*da0073e9SAndroid Build Coastguard Worker        # Eager
2610*da0073e9SAndroid Build Coastguard Worker        eager_res = MyModule()(torch.ones(10, 10))
2611*da0073e9SAndroid Build Coastguard Worker
2612*da0073e9SAndroid Build Coastguard Worker        # Compile
2613*da0073e9SAndroid Build Coastguard Worker        optim_res = torch._dynamo.optimize(cnt)(MyModule())(torch.ones(10, 10))
2614*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager_res, optim_res)
2615*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
2616*da0073e9SAndroid Build Coastguard Worker
2617*da0073e9SAndroid Build Coastguard Worker    def test_module_setattr(self):
2618*da0073e9SAndroid Build Coastguard Worker        models = torch.nn.Sequential(torch.nn.Linear(3, 3))
2619*da0073e9SAndroid Build Coastguard Worker        models[0].abc = False
2620*da0073e9SAndroid Build Coastguard Worker
2621*da0073e9SAndroid Build Coastguard Worker        def run():
2622*da0073e9SAndroid Build Coastguard Worker            models[0].abc = True
2623*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(1, 3)
2624*da0073e9SAndroid Build Coastguard Worker            return models(x)
2625*da0073e9SAndroid Build Coastguard Worker
2626*da0073e9SAndroid Build Coastguard Worker        run = torch.compile(run, fullgraph=True)
2627*da0073e9SAndroid Build Coastguard Worker        run()
2628*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(models[0].abc)
2629*da0073e9SAndroid Build Coastguard Worker
2630*da0073e9SAndroid Build Coastguard Worker    def test_assign_does_not_exist(self):
2631*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
2632*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2633*da0073e9SAndroid Build Coastguard Worker                self.text_encoding = x + 1
2634*da0073e9SAndroid Build Coastguard Worker                return self.text_encoding
2635*da0073e9SAndroid Build Coastguard Worker
2636*da0073e9SAndroid Build Coastguard Worker        mod = MyModule()
2637*da0073e9SAndroid Build Coastguard Worker        out = torch.compile(mod, fullgraph=True)(torch.randn(10))
2638*da0073e9SAndroid Build Coastguard Worker        assert mod.text_encoding is out
2639*da0073e9SAndroid Build Coastguard Worker
2640*da0073e9SAndroid Build Coastguard Worker    def test_module_dict_iter_values(self):
2641*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
2642*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2643*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2644*da0073e9SAndroid Build Coastguard Worker                self.activations = torch.nn.ModuleDict(
2645*da0073e9SAndroid Build Coastguard Worker                    [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
2646*da0073e9SAndroid Build Coastguard Worker                )
2647*da0073e9SAndroid Build Coastguard Worker
2648*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2649*da0073e9SAndroid Build Coastguard Worker                for activation in self.activations.values():
2650*da0073e9SAndroid Build Coastguard Worker                    x = activation(x)
2651*da0073e9SAndroid Build Coastguard Worker                return x
2652*da0073e9SAndroid Build Coastguard Worker
2653*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
2654*da0073e9SAndroid Build Coastguard Worker        # Eager
2655*da0073e9SAndroid Build Coastguard Worker        eager_res = MyModule()(torch.ones(10, 10))
2656*da0073e9SAndroid Build Coastguard Worker
2657*da0073e9SAndroid Build Coastguard Worker        # Compile
2658*da0073e9SAndroid Build Coastguard Worker        optim_res = torch._dynamo.optimize(cnt)(MyModule())(torch.ones(10, 10))
2659*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager_res, optim_res)
2660*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
2661*da0073e9SAndroid Build Coastguard Worker
2662*da0073e9SAndroid Build Coastguard Worker    def test_unspecialized_seq(self):
2663*da0073e9SAndroid Build Coastguard Worker        models = torch.nn.Sequential(torch.nn.Linear(3, 3))
2664*da0073e9SAndroid Build Coastguard Worker
2665*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2666*da0073e9SAndroid Build Coastguard Worker            models[0].training = False
2667*da0073e9SAndroid Build Coastguard Worker            return models(x)
2668*da0073e9SAndroid Build Coastguard Worker
2669*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
2670*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 3)
2671*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
2672*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
2673*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
2674*da0073e9SAndroid Build Coastguard Worker
    def test_no_op_assignment(self):
        # Counts how many compiled graphs expose no buffers when forward()
        # re-assigns a tensor attribute to a converted copy of itself.
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # Plain tensor attribute (not registered via register_buffer).
                self.buffer = torch.rand([4])

            def forward(self, x):
                # should be a no-op, but causes dynamo to lose the static input
                x = x + 1
                self.buffer = self.buffer.to(x)
                return self.buffer + x

        compiles_without_buffers = 0

        def debug_compile(gm, *args, **kwargs):
            # Backend stub: bump the counter when the graph module has no
            # buffers (bool added as 0/1), then return it unchanged.
            nonlocal compiles_without_buffers
            compiles_without_buffers += len(list(gm.buffers())) == 0
            return gm

        @torch.compile(backend=debug_compile)
        def foo(mod, x):
            return mod(x)

        mod = Mod()
        foo(mod, torch.rand([4]))
        # The expected count depends on the inline_inbuilt_nn_modules setting.
        if torch._dynamo.config.inline_inbuilt_nn_modules:
            self.assertEqual(compiles_without_buffers, 1)
        else:
            self.assertEqual(compiles_without_buffers, 0)

        # Half-precision input: the expected counter rises by one in either
        # configuration.
        foo(mod, torch.rand([4], dtype=torch.half))
        if torch._dynamo.config.inline_inbuilt_nn_modules:
            self.assertEqual(compiles_without_buffers, 2)
        else:
            self.assertEqual(compiles_without_buffers, 1)

        # Same module, but with a user-defined __setattr__ that only delegates
        # to the parent implementation.
        class Mod2(Mod):
            def __setattr__(self, name, value):
                return super().__setattr__(name, value)

        foo(Mod2(), torch.rand([4]))
        # causes two compilations, bc unimplemented custom setattr
        self.assertTrue(compiles_without_buffers >= 2)
2718*da0073e9SAndroid Build Coastguard Worker
2719*da0073e9SAndroid Build Coastguard Worker    def test_unspec_non_inlinable_module(self):
2720*da0073e9SAndroid Build Coastguard Worker        mod = UnspecNonInlinableModule()
2721*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(mod)
2722*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(100)
2723*da0073e9SAndroid Build Coastguard Worker        actual = opt_fn(x)
2724*da0073e9SAndroid Build Coastguard Worker        expected = mod(x)
2725*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, expected)
2726*da0073e9SAndroid Build Coastguard Worker
    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
    def test_mark_static_previously_seen_tensor(self):
        # This test verifies that dynamo will mark
        # the buffers/params of a module as static
        # even if this param was previously seen
        # (ex. as a different input)
        num_compiles = 0

        def debug_compiler(gm, _):
            # Backend stub: count invocations and inspect graph placeholders.
            nonlocal num_compiles
            num_compiles += 1

            # Placeholders named "l_b_" -- presumably the graph input created
            # for fn's `b` argument; verify against dynamo's naming scheme.
            input_nodes = [
                n for n in gm.graph.nodes if n.op == "placeholder" and n.name == "l_b_"
            ]

            # At least one such placeholder must exist, and each must carry
            # the "unguarded" static-input marker in its tensor_dict metadata.
            self.assertGreater(len(input_nodes), 0)
            for input_node in input_nodes:
                self.assertEqual(
                    input_node.meta["tensor_dict"]["_dynamo_static_input_type"],
                    "unguarded",
                )

            return gm

        class TestModule(torch.nn.Module):
            def __init__(self, buf) -> None:
                super().__init__()
                # Changing this one to nn.Buffer fails because `nn.Buffer` does a .detach()
                # so the value in self.tx.output.side_effects will no longer evaluate to True
                self.register_buffer("buf", buf)

            def forward(self, x):
                return self.buf * x

        @torch._dynamo.optimize(backend=debug_compiler)
        def fn(x, b, mod):
            z = b + 1
            return z * mod(x)

        # The same tensor object is passed both as `b` and as the module's
        # registered buffer.
        buf = torch.ones(2, 2)
        inp = torch.ones(2)
        mod = TestModule(buf)
        fn(inp, buf, mod)
        self.assertEqual(num_compiles, 1)
2772*da0073e9SAndroid Build Coastguard Worker
    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
    def test_mark_static_nn_module_tensor(self):
        # This test verifies that dynamo will mark
        # the nn module tensor attributes as static
        num_compiles = 0

        def debug_compiler(gm, _):
            # Backend stub: count invocations and inspect graph placeholders.
            nonlocal num_compiles
            num_compiles += 1

            # Placeholders named "l_mod_buf" -- presumably the graph input for
            # the closed-over `mod.buf`; verify against dynamo's naming scheme.
            input_nodes = [
                n
                for n in gm.graph.nodes
                if n.op == "placeholder" and n.name == "l_mod_buf"
            ]

            # At least one such placeholder must exist, and each must carry
            # the "unguarded" static-input marker in its tensor_dict metadata.
            self.assertGreater(len(input_nodes), 0)
            for input_node in input_nodes:
                self.assertEqual(
                    input_node.meta["tensor_dict"]["_dynamo_static_input_type"],
                    "unguarded",
                )

            return gm

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # Plain tensor attribute (not registered via register_buffer).
                self.buf = torch.ones(2, 2)

            def forward(self, x):
                return self.buf * x

        mod = TestModule()

        @torch._dynamo.optimize(backend=debug_compiler)
        def fn(x):
            # `mod` is captured from the enclosing scope, not passed in.
            return x * mod(x)

        inp = torch.ones(2)
        fn(inp)
        self.assertEqual(num_compiles, 1)
2815*da0073e9SAndroid Build Coastguard Worker
    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
    @torch._inductor.config.patch("freezing", True)
    @torch.no_grad()
    def test_mark_static_with_freezing(self):
        # This test verifies that dynamo will
        # add buffers/params as attributes of the
        # graph w/ guards if freezing is enabled
        num_compiles = 0

        def debug_compiler(gm, _):
            # Backend stub: count invocations and inspect the captured graph.
            nonlocal num_compiles
            num_compiles += 1

            # NOTE(review): fn below has no parameter named `b`, so a
            # placeholder called "l_b_" may never exist and this zero-count
            # check could be vacuous -- confirm the intended placeholder name.
            input_nodes = [
                n for n in gm.graph.nodes if n.op == "placeholder" and n.name == "l_b_"
            ]
            self.assertEqual(len(input_nodes), 0)
            # The graph module itself must own exactly one buffer.
            self.assertEqual(len(list(gm.buffers())), 1)
            return gm

        class TestModule(torch.nn.Module):
            def __init__(self, buf) -> None:
                super().__init__()
                self.buf = torch.nn.Buffer(buf)

            def forward(self, x):
                return self.buf * x

        @torch._dynamo.optimize(backend=debug_compiler)
        def fn(x, mod):
            return mod(x)

        buf = torch.ones(2, 2)
        inp = torch.ones(2)
        mod = TestModule(buf)
        fn(inp, mod)
        self.assertEqual(num_compiles, 1)
        # Replacing the buffer with a new tensor is expected to trigger a
        # second compile.
        mod.buf = torch.rand_like(buf)
        fn(inp, mod)
        self.assertEqual(num_compiles, 2)
2856*da0073e9SAndroid Build Coastguard Worker
2857*da0073e9SAndroid Build Coastguard Worker    @patch.object(torch._dynamo.config, "guard_nn_modules", True)
2858*da0073e9SAndroid Build Coastguard Worker    def test_guard_on_torch_nn_modules(self):
2859*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/110048
2860*da0073e9SAndroid Build Coastguard Worker
2861*da0073e9SAndroid Build Coastguard Worker        class MockModule(torch.nn.Module):
2862*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2863*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2864*da0073e9SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(10, 10)
2865*da0073e9SAndroid Build Coastguard Worker                self.multiplier = 10
2866*da0073e9SAndroid Build Coastguard Worker
2867*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2868*da0073e9SAndroid Build Coastguard Worker                return self.linear(x) * self.multiplier
2869*da0073e9SAndroid Build Coastguard Worker
2870*da0073e9SAndroid Build Coastguard Worker        mod = MockModule()
2871*da0073e9SAndroid Build Coastguard Worker
2872*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
2873*da0073e9SAndroid Build Coastguard Worker
2874*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend=cnt)
2875*da0073e9SAndroid Build Coastguard Worker        def generate(x, c):
2876*da0073e9SAndroid Build Coastguard Worker            return mod(x) + c
2877*da0073e9SAndroid Build Coastguard Worker
2878*da0073e9SAndroid Build Coastguard Worker        for _ in range(0, 10):
2879*da0073e9SAndroid Build Coastguard Worker            generate(torch.randn(10, 10), 0)
2880*da0073e9SAndroid Build Coastguard Worker            generate(torch.randn(10, 10), 1)
2881*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 2)
2882*da0073e9SAndroid Build Coastguard Worker
2883*da0073e9SAndroid Build Coastguard Worker        # Ensure that modification in user module causes recompile
2884*da0073e9SAndroid Build Coastguard Worker        mod.multiplier = 11
2885*da0073e9SAndroid Build Coastguard Worker        generate(torch.randn(10, 10), 0)
2886*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 3)
2887*da0073e9SAndroid Build Coastguard Worker
2888*da0073e9SAndroid Build Coastguard Worker    def test_setattr_on_compiled_module(self):
2889*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/114844
2890*da0073e9SAndroid Build Coastguard Worker
2891*da0073e9SAndroid Build Coastguard Worker        class ReplayMutation(torch.nn.Module):
2892*da0073e9SAndroid Build Coastguard Worker            def __init__(self, inp_size, out_size, inner_size):
2893*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2894*da0073e9SAndroid Build Coastguard Worker                self.Linear1 = torch.nn.Linear(inp_size, inner_size)
2895*da0073e9SAndroid Build Coastguard Worker                self.Linear2 = torch.nn.Linear(inner_size, out_size)
2896*da0073e9SAndroid Build Coastguard Worker                self.x = None
2897*da0073e9SAndroid Build Coastguard Worker
2898*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp):
2899*da0073e9SAndroid Build Coastguard Worker                res = self.Linear1(inp)
2900*da0073e9SAndroid Build Coastguard Worker                self.x = res
2901*da0073e9SAndroid Build Coastguard Worker                return self.Linear2(res)
2902*da0073e9SAndroid Build Coastguard Worker
2903*da0073e9SAndroid Build Coastguard Worker        N, D_in, H, D_out, inner = 2, 2, 2, 2, 4
2904*da0073e9SAndroid Build Coastguard Worker        model = ReplayMutation(D_in, H, inner)
2905*da0073e9SAndroid Build Coastguard Worker        model2 = copy.deepcopy(model)
2906*da0073e9SAndroid Build Coastguard Worker        input = torch.ones(N, D_in)
2907*da0073e9SAndroid Build Coastguard Worker
2908*da0073e9SAndroid Build Coastguard Worker        # Keep some intermediate value in model.x
2909*da0073e9SAndroid Build Coastguard Worker        model.x = torch.tensor([[100, 100, 100, 100], [200, 200, 200, 200]])
2910*da0073e9SAndroid Build Coastguard Worker        model(input)
2911*da0073e9SAndroid Build Coastguard Worker
2912*da0073e9SAndroid Build Coastguard Worker        compiled_model = torch.compile(model2, backend="eager")
2913*da0073e9SAndroid Build Coastguard Worker        compiled_model.x = torch.tensor([[100, 100, 100, 100], [200, 200, 200, 200]])
2914*da0073e9SAndroid Build Coastguard Worker        compiled_model(input)
2915*da0073e9SAndroid Build Coastguard Worker
2916*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(model.x, compiled_model.x)
2917*da0073e9SAndroid Build Coastguard Worker
    def test_globals_change_in_other_file(self):
        # Calls two global-updating helpers (one from this module, one from
        # test_functions) inside a fullgraph-compiled function and checks the
        # globals afterwards. The helpers and globals are defined outside this
        # block; presumably each sets its global(s) to 1 -- verify there.
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            update_global()
            a = test_functions.update_global(x)
            # Ensure that the updated global values are read
            return x * a * (_variable + _variable1 + test_functions._variable)

        res = fn(torch.ones(10))
        self.assertEqual(_variable, 1)
        self.assertEqual(_variable1, 1)
        # Ensure that the reconstructed bytecode updates the global value in the
        # other file.
        self.assertEqual(test_functions._variable, 1)
        self.assertEqual(res, 3 * torch.ones(10))
2933*da0073e9SAndroid Build Coastguard Worker
2934*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
2935*da0073e9SAndroid Build Coastguard Worker        "inductor" not in torch._dynamo.list_backends(),
2936*da0073e9SAndroid Build Coastguard Worker        "inductor backend is not available",
2937*da0073e9SAndroid Build Coastguard Worker    )
2938*da0073e9SAndroid Build Coastguard Worker    def test_save_and_load_inductor(self):
2939*da0073e9SAndroid Build Coastguard Worker        mod = MockModule()
2940*da0073e9SAndroid Build Coastguard Worker        opt_mod = torch.compile(mod, backend="inductor")
2941*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(10, 10)
2942*da0073e9SAndroid Build Coastguard Worker        opt_mod(inp)
2943*da0073e9SAndroid Build Coastguard Worker
2944*da0073e9SAndroid Build Coastguard Worker        with tempfile.TemporaryDirectory() as tmpdirname:
2945*da0073e9SAndroid Build Coastguard Worker            torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
2946*da0073e9SAndroid Build Coastguard Worker            loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
2947*da0073e9SAndroid Build Coastguard Worker        loaded_model(inp)
2948*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same_two_models(loaded_model, mod, [inp]))
2949*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same_two_models(loaded_model, opt_mod, [inp]))
2950*da0073e9SAndroid Build Coastguard Worker
2951*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()  # force recompiles
2952*da0073e9SAndroid Build Coastguard Worker        torch._inductor.metrics.generated_kernel_count = 0
2953*da0073e9SAndroid Build Coastguard Worker        loaded_model(inp)
2954*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(torch._inductor.metrics.generated_kernel_count, 0)
2955*da0073e9SAndroid Build Coastguard Worker
2956*da0073e9SAndroid Build Coastguard Worker    def test_save_and_load_all_backends(self):
2957*da0073e9SAndroid Build Coastguard Worker        mod = MockModule()
2958*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(10, 10)
2959*da0073e9SAndroid Build Coastguard Worker        for backend in torch._dynamo.list_backends():
2960*da0073e9SAndroid Build Coastguard Worker            try:
2961*da0073e9SAndroid Build Coastguard Worker                opt_mod = torch.compile(mod, backend=backend)
2962*da0073e9SAndroid Build Coastguard Worker                with tempfile.TemporaryDirectory() as tmpdirname:
2963*da0073e9SAndroid Build Coastguard Worker                    torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
2964*da0073e9SAndroid Build Coastguard Worker                    loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
2965*da0073e9SAndroid Build Coastguard Worker                torch._dynamo.reset()  # force recompiles
2966*da0073e9SAndroid Build Coastguard Worker                torch._inductor.metrics.generated_kernel_count = 0
2967*da0073e9SAndroid Build Coastguard Worker                opt_mod(inp)
2968*da0073e9SAndroid Build Coastguard Worker                opt_success = torch._inductor.metrics.generated_kernel_count == 0
2969*da0073e9SAndroid Build Coastguard Worker                torch._dynamo.reset()  # force recompiles
2970*da0073e9SAndroid Build Coastguard Worker                torch._inductor.metrics.generated_kernel_count = 0
2971*da0073e9SAndroid Build Coastguard Worker                loaded_model(inp)
2972*da0073e9SAndroid Build Coastguard Worker                loaded_success = torch._inductor.metrics.generated_kernel_count == 0
2973*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(opt_success, loaded_success)
2974*da0073e9SAndroid Build Coastguard Worker            except torch._dynamo.exc.BackendCompilerFailed:
2975*da0073e9SAndroid Build Coastguard Worker                pass
2976*da0073e9SAndroid Build Coastguard Worker
2977*da0073e9SAndroid Build Coastguard Worker    def test_monkeypatching_forward(self):
2978*da0073e9SAndroid Build Coastguard Worker        class FakeModule(torch.nn.Module):
2979*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2980*da0073e9SAndroid Build Coastguard Worker                return torch.sin(x)
2981*da0073e9SAndroid Build Coastguard Worker
2982*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
2983*da0073e9SAndroid Build Coastguard Worker            def __init__(self, x):
2984*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2985*da0073e9SAndroid Build Coastguard Worker
2986*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2987*da0073e9SAndroid Build Coastguard Worker                return torch.cos(x)
2988*da0073e9SAndroid Build Coastguard Worker
2989*da0073e9SAndroid Build Coastguard Worker        def helper():
2990*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.reset()
2991*da0073e9SAndroid Build Coastguard Worker            mod = MyModule(3)
2992*da0073e9SAndroid Build Coastguard Worker
2993*da0073e9SAndroid Build Coastguard Worker            def fn(x):
2994*da0073e9SAndroid Build Coastguard Worker                return mod(x)
2995*da0073e9SAndroid Build Coastguard Worker
2996*da0073e9SAndroid Build Coastguard Worker            cnt = torch._dynamo.testing.CompileCounter()
2997*da0073e9SAndroid Build Coastguard Worker            opt_fn = torch._dynamo.optimize(cnt)(fn)
2998*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(10)
2999*da0073e9SAndroid Build Coastguard Worker
3000*da0073e9SAndroid Build Coastguard Worker            opt_fn(x)
3001*da0073e9SAndroid Build Coastguard Worker            opt_fn(x)
3002*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnt.frame_count, 1)
3003*da0073e9SAndroid Build Coastguard Worker
3004*da0073e9SAndroid Build Coastguard Worker            # Monkeypatch forward
3005*da0073e9SAndroid Build Coastguard Worker            mod.forward = types.MethodType(FakeModule.forward, mod)
3006*da0073e9SAndroid Build Coastguard Worker            ref = fn(x)
3007*da0073e9SAndroid Build Coastguard Worker            res = opt_fn(x)
3008*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ref, res)
3009*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnt.frame_count, 2)
3010*da0073e9SAndroid Build Coastguard Worker
3011*da0073e9SAndroid Build Coastguard Worker        helper()
3012*da0073e9SAndroid Build Coastguard Worker        with torch._dynamo.config.patch(inline_inbuilt_nn_modules=True):
3013*da0073e9SAndroid Build Coastguard Worker            helper()
3014*da0073e9SAndroid Build Coastguard Worker
3015*da0073e9SAndroid Build Coastguard Worker    def test_user_defined_nn_module_dynamic(self):
3016*da0073e9SAndroid Build Coastguard Worker        class Conv2d(torch.nn.Conv2d):
3017*da0073e9SAndroid Build Coastguard Worker            def __init__(self, *args, **kwargs):
3018*da0073e9SAndroid Build Coastguard Worker                super().__init__(*args, **kwargs)
3019*da0073e9SAndroid Build Coastguard Worker
3020*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3021*da0073e9SAndroid Build Coastguard Worker                x = torch.nn.functional.conv2d(
3022*da0073e9SAndroid Build Coastguard Worker                    x,
3023*da0073e9SAndroid Build Coastguard Worker                    self.weight,
3024*da0073e9SAndroid Build Coastguard Worker                    self.bias,
3025*da0073e9SAndroid Build Coastguard Worker                    self.stride,
3026*da0073e9SAndroid Build Coastguard Worker                    self.padding,
3027*da0073e9SAndroid Build Coastguard Worker                    self.dilation,
3028*da0073e9SAndroid Build Coastguard Worker                    self.groups,
3029*da0073e9SAndroid Build Coastguard Worker                )
3030*da0073e9SAndroid Build Coastguard Worker                return x
3031*da0073e9SAndroid Build Coastguard Worker
3032*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3033*da0073e9SAndroid Build Coastguard Worker        mod1 = Conv2d(64, 64, kernel_size=(2, 2), stride=(1, 1))
3034*da0073e9SAndroid Build Coastguard Worker        mod2 = Conv2d(64, 64, kernel_size=(2, 2), stride=(2, 2))
3035*da0073e9SAndroid Build Coastguard Worker        mod3 = Conv2d(64, 64, kernel_size=(2, 2), stride=(3, 3))
3036*da0073e9SAndroid Build Coastguard Worker
3037*da0073e9SAndroid Build Coastguard Worker        opt_mod1 = torch.compile(mod1, backend=cnts, fullgraph=True)
3038*da0073e9SAndroid Build Coastguard Worker        opt_mod2 = torch.compile(mod2, backend=cnts, fullgraph=True)
3039*da0073e9SAndroid Build Coastguard Worker        opt_mod3 = torch.compile(mod3, backend=cnts, fullgraph=True)
3040*da0073e9SAndroid Build Coastguard Worker
3041*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 64, 64, 64)
3042*da0073e9SAndroid Build Coastguard Worker        opt_mod1(x)
3043*da0073e9SAndroid Build Coastguard Worker        opt_mod2(x)
3044*da0073e9SAndroid Build Coastguard Worker        opt_mod3(x)
3045*da0073e9SAndroid Build Coastguard Worker
3046*da0073e9SAndroid Build Coastguard Worker        # Must be 3 compilations. If not marked static there would be 2, because strides would be converted to symints.
3047*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 3)
3048*da0073e9SAndroid Build Coastguard Worker
3049*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Script entry point: hand control to the dynamo test runner.
    torch._dynamo.test_case.run_tests()
3054