# Owner(s): ["oncall: jit"]

import io
import unittest
from itertools import product
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit._recursive import wrap_cpp_module
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import (
    set_default_dtype,
    skipCUDAMemoryLeakCheckIf,
    skipIfTorchDynamo,
    TEST_WITH_ROCM,
)
from torch.testing._internal.jit_utils import JitTestCase
from torch.utils import mkldnn as mkldnn_utils


try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )

TEST_ROCM = torch.cuda.is_available() and torch.version.hip is not None


def removeExceptions(graph):
    for n in graph.findAllNodes("prim::RaiseException"):
        n.destroy()


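# A minimal sketch (not exercised by the tests below) of the workflow this
# suite covers: script a module, put it in eval mode, freeze it, and run the
# frozen copy. `_example_freeze_workflow` is an illustrative helper, not part
# of the original suite.
def _example_freeze_workflow():
    model = torch.jit.script(nn.Linear(4, 4))
    model.eval()  # freezing requires eval mode
    # torch.jit.freeze folds parameters and attributes into the graph as constants
    frozen = torch.jit.freeze(model)
    return frozen(torch.randn(2, 4))

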
class TestFreezing(JitTestCase):
    def test_freeze_module(self):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 1  # folded
                self.b = 1.2  # folded
                self.c = "hello"  # folded
                self.c2 = "hi\xA1"  # not folded
                self.d = [1, 1]  # folded
                self.e = [1.0, 1.1]  # folded
                self.f = ["hello", "world"]  # folded
                self.f2 = [(1, "Over \u0e55\u0e57 57")]
                self.g = (
                    [1, 2],
                    3.2,
                    "4.4",
                    torch.tensor([5.5], requires_grad=True),
                )  # folded
                self.h = {"layer": [torch.tensor([7.7], requires_grad=True)]}
                self.h2 = {"layer\xB1": [torch.tensor([8.8], requires_grad=True)]}
                self.t = torch.tensor([1.2, 2.4], requires_grad=True)  # folded
                self.ts = [
                    torch.tensor([1.0, 2.0], requires_grad=True),
                    torch.tensor([3.0, 4.0], requires_grad=True),
                ]  # folded
                self.tt = [[torch.tensor([3.3, 2.3], requires_grad=True), None]]

            def forward(self, x):
                return (
                    str(self.a)
                    + str(self.b)
                    + self.c
                    + self.c2
                    + str(self.d)
                    + str(self.e)
                    + str(self.f)
                    + str(self.f2)
                    + str(self.g)
                    + str(self.h)
                    + str(self.h2)
                    + str(self.t)
                    + str(self.ts)
                    + str(self.tt)
                )

        m = torch.jit.script(M())
        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)
        m._c = torch._C._freeze_module(m._c)
        buffer = io.BytesIO()
        torch.jit.save(m._c, buffer)
        buffer.seek(0)
        m2 = torch.jit.load(buffer)
        # Check if frozen module looks as below:
        # module m {
        #   attributes {
        #     tt = ...
        #   }
        #   ...
        # }
        self.assertFalse(m2._c.hasattr("a"))
        self.assertFalse(m2._c.hasattr("b"))
        self.assertFalse(m2._c.hasattr("c"))
        self.assertFalse(m2._c.hasattr("c2"))
        self.assertFalse(m2._c.hasattr("d"))
        self.assertFalse(m2._c.hasattr("e"))
        self.assertFalse(m2._c.hasattr("f"))
        self.assertFalse(m2._c.hasattr("f2"))
        self.assertFalse(m2._c.hasattr("g"))
        self.assertFalse(m2._c.hasattr("h"))
        self.assertFalse(m2._c.hasattr("h2"))
        self.assertFalse(m2._c.hasattr("t"))
        self.assertFalse(m2._c.hasattr("ts"))
        self.assertFalse(m2._c.hasattr("tt"))
        output_f = m2.forward(input)
        self.assertEqual(output_s, output_f)

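    # The test above drives the internal entry point torch._C._freeze_module
    # directly. A minimal sketch of the same round trip through the public
    # API; the helper is illustrative and not collected as a test.
    def _example_public_freeze_roundtrip(self):
        m = torch.jit.script(nn.Linear(2, 2)).eval()
        mf = torch.jit.freeze(m)  # public wrapper over torch._C._freeze_module
        buffer = io.BytesIO()
        torch.jit.save(mf, buffer)  # frozen modules serialize like any ScriptModule
        buffer.seek(0)
        return torch.jit.load(buffer)
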
    def test_freeze_module_with_submodule(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 11
                self.b = 2

            def forward(self, x):
                return self.a + self.b

        class SubModule2(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 12
                self.b = 2

            def forward(self, x):
                self.b = 30
                return self.a + self.b

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = SubModule()
                self.sub2 = SubModule2()
                self.a = 3
                self.b = 4

            def forward(self, x):
                self.b = 20
                return self.sub1(x) + self.a + self.b + self.sub2(x)

        m = torch.jit.script(TestModule())
        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)
        mf = torch.jit.freeze(m)

        # Check if frozen module looks as below:
        # module m {
        #   attributes {
        #     sub2 = ...
        #     b = ...
        #   }
        #   ...
        #   submodule {
        #     module m {
        #       attributes {
        #         sub2 = ...
        #         b = ...
        #       }
        #       ...
        #     }
        #   }
        # }
        mf = mf._c
        self.assertFalse(mf.hasattr("sub1"))
        self.assertFalse(mf.hasattr("a"))
        self.assertTrue(mf.hasattr("b"))
        self.assertTrue(mf.hasattr("sub2"))
        self.assertTrue(mf.sub2.hasattr("b"))  # verify b is preserved in sub2
        self.assertFalse(mf.sub2.hasattr("a"))  # verify a is removed in sub2
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

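    # Attributes that are mutated in forward (like 'b' above) survive freezing
    # and are written through prim::SetAttr nodes. A minimal sketch of checking
    # for that in the frozen graph, assuming the node name stays prim::SetAttr;
    # the helper is illustrative and not collected as a test.
    def _example_check_preserved_mutation(self):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = 4

            def forward(self, x):
                self.b = 20
                return x + self.b

        mf = torch.jit.freeze(torch.jit.script(M()).eval())
        FileCheck().check("prim::SetAttr").run(mf.graph)
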
    def test_freeze_module_with_fork(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(20, 20)
                self.b = torch.ones(20, 20)

            def forward(self, x):
                return self.a * self.b + x

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule()

            def forward(self, x):
                fut = torch.jit._fork(self.sub.forward, x)
                y_hat = self.sub(x)
                y = torch.jit._wait(fut)
                return y_hat + y

        m = torch.jit.script(TestModule())
        m.eval()
        input = torch.randn(20, 20)
        output_s = m.forward(input)
        mf = torch._C._freeze_module(m._c)

        # Check if frozen module looks as below:
        # module m {
        #   attributes {
        #   }
        #   ...
        #   submodule {
        #   }
        # }
        self.assertFalse(mf.hasattr("a"))
        self.assertFalse(mf.hasattr("b"))
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

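    # torch.jit._fork returns a Future for the callee and torch.jit._wait
    # retrieves its result; actual parallelism only happens inside TorchScript,
    # while in eager Python the call runs immediately. A minimal sketch of the
    # pattern the fork tests rely on (illustrative helper, not a test):
    def _example_fork_wait(self):
        def double(x):
            return x * 2

        fut = torch.jit._fork(double, torch.ones(2))
        return torch.jit._wait(fut)
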
    def test_freeze_module_with_nested_fork(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(20, 20)
                self.b = torch.ones(20, 20)

            def forward(self, x):
                return self.a * self.b + x

        class SubModule2(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule()
                self.c = torch.ones(20, 20)

            def forward(self, x):
                fut = torch.jit._fork(self.sub.forward, x)
                y_hat = self.sub(x)
                y = torch.jit._wait(fut)
                return y_hat + y + self.c

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule2()
                self.d = 1

            def forward(self, x):
                fut = torch.jit._fork(self.sub.forward, x)
                y_hat = self.sub(x)
                y = torch.jit._wait(fut)
                self.d = 2
                return y_hat * y + self.d

        m = torch.jit.script(TestModule())
        m.eval()
        input = torch.randn(20, 20)
        output_s = m.forward(input)
        mf = torch._C._freeze_module(m._c)
        # Check if frozen module looks as below:
        # module m {
        #   attributes {
        #   }
        #   ...
        #   submodule {
        #   }
        # }
        self.assertFalse(mf.hasattr("a"))
        self.assertFalse(mf.hasattr("b"))
        self.assertFalse(mf.hasattr("c"))
        self.assertTrue(mf.hasattr("d"))
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_fork2(self):
        @torch.jit.script
        def foo(x):
            return x * 2

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(20, 20)
                self.b = torch.ones(20, 20)

            def forward(self, x):
                fut = torch.jit._fork(foo, self.a)
                y_hat = foo(self.b)
                y = torch.jit._wait(fut)
                return y_hat + y

        m = torch.jit.script(TestModule())
        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)
        mf = torch._C._freeze_module(m._c)

        # Check if frozen module looks as below:
        # module m {
        #   attributes {
        #     self.a = ...
        #   }
        #   ...
        #   submodule {
        #   }
        # }
        # TODO: Although there is no mutation, the alias analysis conservatively
        # assumes a mutation because 'a' is passed to the fork subgraph, so 'a'
        # is preserved while 'b' is folded.
        self.assertTrue(mf.hasattr("a"))
        self.assertFalse(mf.hasattr("b"))
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_fork_calling_module_method(self):
        @torch.jit.script
        def foo(x, y):
            return x * y

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(20, 20)
                self.b = torch.ones(20, 20)

            @torch.jit.export
            def foo(self, x):
                return x * self.a

            @torch.jit.export
            def bar(self, x):
                return x * self.b

            def forward(self, x):
                fut = torch.jit._fork(self.foo, self.b)
                y_hat = self.bar(self.a)
                y = torch.jit._wait(fut)
                return y_hat + y

        m = torch.jit.script(TestModule())
        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)
        mf = torch._C._freeze_module(m._c)
        # Check if frozen module looks as below:
        # module m {
        #   attributes {
        #     self.b = ...
        #   }
        #   ...
        # TODO: Although there is no mutation, the alias analysis conservatively
        # assumes a mutation because 'b' is passed to the fork subgraph, so 'b'
        # is preserved.
        self.assertFalse(mf.hasattr("a"))
        self.assertTrue(mf.hasattr("b"))
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_sharedclasstype(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = torch.tensor([2.2])

            def forward(self, x):
                return self.a + self.b

            @torch.jit.export
            def modify_a(self, x):
                self.a[0] += 10
                return self.b

            @torch.jit.export
            def modify_b(self, x):
                self.b[0] += 20
                return self.a

        class SubModule2(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule()
                self.b = torch.tensor([3.3])

            def forward(self, x):
                y = self.sub.modify_b(x)
                return y + self.b

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = SubModule()  # sub1 and sub2.sub share the same class type.
                self.sub2 = SubModule2()
                self.a = torch.tensor([4.4])

            def forward(self, x):
                z = self.sub1.modify_a(x)
                return self.sub2(x) + z + self.a

        m = torch.jit.script(TestModule())
        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)
        mf = torch._C._freeze_module(m._c)

        # Check if frozen module looks as below:
        # module mf {
        #   attributes {
        #     sub1 = ...
        #     sub2 = ...
        #   }
        #   ...
        #   submodules {
        #     module sub1 {
        #       attributes {
        #         a = ...
        #         b = ...
        #       }
        #       ...
        #     }
        #     module sub2 {
        #       attributes {
        #         sub = ...
        #       }
        #       ...
        #       submodule {
        #         module sub {
        #           attributes {
        #             a = ...
        #             b = ...
        #           }
        #           ...
        #         }
        #       }
        #     }
        #   }
        # }

        self.assertTrue(mf.hasattr("sub1"))
        self.assertTrue(mf.sub1.hasattr("a"))
        self.assertTrue(mf.sub1.hasattr("b"))
        self.assertFalse(mf.hasattr("a"))
        self.assertTrue(mf.hasattr("sub2"))
        self.assertTrue(mf.sub2.hasattr("sub"))
        self.assertFalse(mf.sub2.hasattr("b"))
        self.assertTrue(mf.sub2.sub.hasattr("a"))
        self.assertTrue(mf.sub2.sub.hasattr("b"))
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_nestedaliasing(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = torch.tensor([2.2])

            def forward(self, x):
                return self.a + self.b

            @torch.jit.export
            def modify_a(self, x):
                self.a[0] = 10
                return self.b

            @torch.jit.export
            def modify_b(self, x):
                self.b[0] = 20
                return self.a

        Sub = SubModule()

        class SubModule2(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = Sub  # aliasing

            def forward(self, x):
                return self.sub.a

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = Sub  # aliasing
                self.sub2 = SubModule2()

            def forward(self, x):
                z = self.sub1.modify_a(x)
                return self.sub2(x) + z

        m = torch.jit.script(TestModule())
        m.eval()
        mf = torch._C._freeze_module(m._c)
        self.assertTrue(mf.hasattr("sub1"))
        self.assertTrue(mf.sub1.hasattr("a"))
        self.assertFalse(mf.sub1.hasattr("b"))
        self.assertTrue(mf.hasattr("sub2"))
        self.assertTrue(mf.sub2.hasattr("sub"))
        self.assertTrue(
            mf.sub2.sub.hasattr("a")
        )  # Freezing detects that self.sub2.sub.a and self.sub1.a are aliases
        self.assertFalse(mf.sub2.sub.hasattr("b"))
        input = torch.randn(2, 2)
        output_s = m.forward(input)
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    # FIXME: JIT is not honoring aliasing. The 'Sub' module is copied, so the
    # eager and scripted modules produce different output.
    def test_freeze_module_with_nestedaliasingscalar(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 1.1
                self.b = 2.2

            def forward(self, x):
                return self.a + self.b

            @torch.jit.export
            def modify_a(self, x):
                self.a = 10.0
                return self.b

            @torch.jit.export
            def modify_b(self, x):
                self.b = 20.0
                return self.a

        Sub = SubModule()

        class SubModule2(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = Sub  # aliasing

            def forward(self, x):
                return self.sub.a

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = Sub  # aliasing
                self.sub2 = SubModule2()

            def forward(self, x):
                z = self.sub1.modify_a(x)
                return self.sub2(x) + z

        m = TestModule()
        ms = torch.jit.script(m)
        ms.eval()
        mf = torch._C._freeze_module(ms._c)
        self.assertTrue(mf.hasattr("sub1"))
        self.assertTrue(mf.sub1.hasattr("a"))
        self.assertFalse(mf.sub1.hasattr("b"))
        # sub2 is fully folded because self.sub1 and self.sub2.sub are not aliases (scripting bug)
        self.assertFalse(mf.hasattr("sub2"))
        input = torch.randn(2, 2)
        output = m.forward(input)
        output_s = ms.forward(input)
        output_f = mf.forward(input)
        # These should be equal; they differ because of the scripting
        # aliasing bug noted in the FIXME above.
        self.assertNotEqual(output, output_s)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_preserve_sub_module(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = 2.2

            def forward(self, x):
                return self.a

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = SubModule()  # independent instance; no aliasing here
                self.sub2 = SubModule()

            def forward(self, x):
                return self.sub2(x) + self.sub1(x)

        m = TestModule()
        ms = torch.jit.script(m)
        ms.eval()
        mf = torch._C._freeze_module(ms._c, ["sub1"])

        # Test that 'sub1' is preserved entirely and 'sub2' is completely folded
        self.assertTrue(mf.hasattr("sub1"))
        self.assertTrue(mf.sub1.hasattr("a"))
        self.assertTrue(mf.sub1.hasattr("b"))
        self.assertFalse(mf.hasattr("sub2"))
        input = torch.randn(2, 2)
        output_s = ms.forward(input)
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

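    # The same preservation request through the public API: torch.jit.freeze
    # accepts preserved_attrs, a list of attribute names to keep as mutable
    # attributes instead of folding. A minimal sketch (illustrative helper,
    # not collected as a test):
    def _example_preserved_attrs(self):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.version = 1

            def forward(self, x):
                return x + self.version

        mf = torch.jit.freeze(
            torch.jit.script(M()).eval(), preserved_attrs=["version"]
        )
        assert mf.version == 1  # still an attribute, not a folded constant
        return mf
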
    def test_freeze_module_with_preserve_sub_module_and_mutation(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = 2.2

            def forward(self, x):
                self.a[0] = 3.3
                return self.a

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = SubModule()  # independent instance; no aliasing here
                self.sub2 = SubModule()

            def forward(self, x):
                return self.sub2(x) + self.sub1(x)

        m = TestModule()
        ms = torch.jit.script(m)
        ms.eval()
        mf = torch._C._freeze_module(ms._c, ["sub1"])

        # Test that both 'sub1' and 'sub2' are preserved, and that 'b' is
        # preserved even though it is not used, to fulfill the user request
        # to preserve 'sub1'.
        self.assertTrue(mf.hasattr("sub1"))
        self.assertTrue(mf.sub1.hasattr("a"))
        self.assertTrue(mf.sub1.hasattr("b"))
        self.assertTrue(mf.hasattr("sub2"))
        self.assertTrue(mf.sub2.hasattr("a"))
        self.assertTrue(mf.sub2.hasattr("b"))
        input = torch.randn(2, 2)
        output_s = ms.forward(input)
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_helperfunction(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 11
                self.b = 2

            def forward(self, x):
                return self.a + self.b

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule()
                self.a = 3
                self.b = 4

            def forward(self, x):
                self.b = 20
                return self._forward(x) + self.a + self.b

            def _forward(self, x):
                return self.sub(x)

        m = torch.jit.script(TestModule())
        m.eval()
        input = torch.randn(2, 2)
        mf = torch._C._freeze_module(m._c)
        self.assertFalse(mf.hasattr("sub"))
        self.assertFalse(mf.hasattr("a"))
        self.assertTrue(mf.hasattr("b"))
        with self.assertRaisesRegex(
            AttributeError, "TestModule (.*) does not have a field with name '_forward'"
        ):
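            # The attribute lookup on mf fails before `x` would be evaluated,
            # so AttributeError (not NameError for the undefined `x`) is raised.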
            mf._forward(x)  # noqa: F821

    def test_freeze_module_with_inplace_mutable(self):
        class FreezeMe(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.a = [11, 22]

            @torch.jit.script_method
            def forward(self, x):
                for i in range(3):
                    self.a.append(i)
                return self.a

        m = FreezeMe()
        m.eval()
        m_f = torch._C._freeze_module(m._c)
        self.assertTrue(m_f.hasattr("a"))
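        # m and m_f share the same list object, so the eager call and the
        # frozen call below each append [0, 1, 2] to it.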
705*da0073e9SAndroid Build Coastguard Worker        m.forward(torch.tensor([3]))
706*da0073e9SAndroid Build Coastguard Worker        out = m_f.forward(torch.tensor([5]))
707*da0073e9SAndroid Build Coastguard Worker        expected = [11, 22, 0, 1, 2, 0, 1, 2]
708*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, expected)
709*da0073e9SAndroid Build Coastguard Worker
710*da0073e9SAndroid Build Coastguard Worker    # Mutable attributes
711*da0073e9SAndroid Build Coastguard Worker    def test_freeze_module_with_mutable_list(self):
712*da0073e9SAndroid Build Coastguard Worker        class FreezeMe(nn.Module):
713*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
714*da0073e9SAndroid Build Coastguard Worker                super().__init__()
715*da0073e9SAndroid Build Coastguard Worker                self.a = [1, 2]
716*da0073e9SAndroid Build Coastguard Worker
717*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
718*da0073e9SAndroid Build Coastguard Worker                return self.a
719*da0073e9SAndroid Build Coastguard Worker
720*da0073e9SAndroid Build Coastguard Worker        m = FreezeMe()
721*da0073e9SAndroid Build Coastguard Worker        m.eval()
722*da0073e9SAndroid Build Coastguard Worker        m.a.append(3)
723*da0073e9SAndroid Build Coastguard Worker        m_s = torch.jit.script(m)
724*da0073e9SAndroid Build Coastguard Worker        v = m_s.a
725*da0073e9SAndroid Build Coastguard Worker        v.append(4)
726*da0073e9SAndroid Build Coastguard Worker        m_s.a = v
727*da0073e9SAndroid Build Coastguard Worker        m_s.eval()
728*da0073e9SAndroid Build Coastguard Worker        m_f = torch._C._freeze_module(m_s._c)
729*da0073e9SAndroid Build Coastguard Worker        # Post-freezing mutating m_s.a  does not affect m_f (m_f has its own copy).
730*da0073e9SAndroid Build Coastguard Worker        v = m_s.a
731*da0073e9SAndroid Build Coastguard Worker        v.append(5)
732*da0073e9SAndroid Build Coastguard Worker        m_s.a = v
733*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m_f.hasattr("a"))
734*da0073e9SAndroid Build Coastguard Worker        out = m_f.forward(torch.tensor([5]))
735*da0073e9SAndroid Build Coastguard Worker        expected = [1, 2, 3, 4]
736*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, expected)
737*da0073e9SAndroid Build Coastguard Worker
    def test_freeze_module_with_mutable_dict(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = {"layer": "4"}

            def forward(self, x):
                return self.a

            @torch.jit.export
            def modify_a(self, x):
                self.a["layer"] = self.a["layer"] + "1"
                return self.a

        m = FreezeMe()
        m.eval()
        m.a["layer2"] = "3"
        m_s = torch.jit.script(m)
        t = torch.tensor(5)
        m_s.modify_a(t)
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        m.a["layer2"] += "2"
        m_s.modify_a(t)
        self.assertFalse(m_f.hasattr("a"))
        out = m_f.forward(t)
        expected = {"layer": "411", "layer2": "3"}
        self.assertEqual(out, expected)

    def test_freeze_module_with_mutable_tensor(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.0, 2.0, 3.0])

            def forward(self, x):
                return self.a

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.a[1] += 3.0
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        # Mutating the tensor attribute after freezing still affects m_f,
        # because the folded constant shares storage with m_s.a.
        # FIXME: deep-copy all folded attributes so that m_f has full ownership.
        m_s.a[0] += 5.0
        self.assertFalse(m_f.hasattr("a"))
        out = m_f.forward(torch.tensor([5]))
        expected = [6.0, 5.0, 3.0]
        self.assertEqual(out, expected)

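    # Illustrative sketch (ours, not part of the suite; the helper name is
    # hypothetical) of the ownership caveat in the FIXME above: the folded
    # tensor constant still shares storage with the original attribute, so
    # in-place writes through the scripted module leak into the frozen one.
    def _example_folded_tensor_shares_storage(self):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.0, 2.0])

            def forward(self, x):
                return self.a

        m_s = torch.jit.script(M().eval())
        m_f = torch.jit.freeze(m_s)
        m_s.a[0] += 5.0  # in-place write after freezing
        self.assertEqual(m_f(torch.tensor([0])), torch.tensor([6.0, 2.0]))
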
    def test_freeze_module_with_tuple(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = (torch.tensor([1, 2, 3, 4, 5, 6]), "hi")

            def forward(self, x):
                if x[0] == 2.0:
                    self.a[0][0] = 10
                return self.a[0].sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        inp = torch.tensor([2.0])
        expected = m_s.forward(inp)
        m_s.a[0][0] = 1
        m_f = torch._C._freeze_module(m_s._c)
        self.assertFalse(m_f.hasattr("a"))
        out = m_f.forward(inp)
        self.assertEqual(out, expected)

    def test_freeze_module_with_tensor(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1, 2, 3, 4, 5, 6])

            def forward(self, x):
                x = self.a.view(2, 3)
                x[0][0] += 10
                return self.a.sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        inp = torch.tensor([5])
        expected = m_s.forward(inp)
        m_f = torch._C._freeze_module(m_s._c)
        self.assertTrue(m_f.hasattr("a"))
        m_f.a[0] -= 10
        out = m_f.forward(inp)
        self.assertEqual(out, expected)

    def test_freeze_module_with_list(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = [torch.tensor([1, 2, 3, 4, 5, 6])]

            def forward(self, x):
                self.a[0][1] += 10
                return self.a[0].sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        inp = torch.tensor([5])
        expected = m_s.forward(inp)
        m_s.a[0][1] -= 10
        m_f = torch._C._freeze_module(m_s._c)
        self.assertFalse(m_f.hasattr("a"))
        out = m_f.forward(inp)
        self.assertEqual(out, expected)

    def test_freeze_module_with_aliased_tensor_attr(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1, 2, 3, 4, 5, 6])
                self.b = self.a.view(2, 3)

            def forward(self, x):
                self.b[1] += 10
                return self.a.sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        self.assertTrue(m_f.hasattr("a"))
        inp = torch.tensor([5])
        out = m_f.forward(inp)
        expected = torch.tensor(51)  # 1+2+3+14+15+16
        self.assertEqual(out, expected)

    def test_freeze_module_with_aliased_tensor_attr2(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1, 2, 3, 4, 5, 6])
                self.b = {"layer": ([self.a.view(2, 3), torch.tensor([10])], 20)}
                self.c = ([self.a.view(2, 3), torch.tensor([10])], 20)
                self.d = (self.a.view(2, 3), 20)

            def forward(self, x):
                self.d[0][0] += 10
                return self.a.sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        inp = torch.tensor([5])
        expected = m_s.forward(inp)
        with self.assertRaisesRegex(
            RuntimeError, "module contains attributes values that overlaps"
        ):
            m_f = torch._C._freeze_module(m_s._c)

    def test_freeze_module_with_aliased_tensor_attr3(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1, 2, 3, 4, 5, 6])
                self.b = [self.a, torch.tensor([10])]

            def forward(self, x):
                self.a[1] += 10
                return self.b[0].sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        inp = torch.tensor([5])
        expected = m_s.forward(inp)
        m_f = torch._C._freeze_module(m_s._c)
        self.assertTrue(m_f.hasattr("a"))
        self.assertTrue(m_f.hasattr("b"))
        out = m_f.forward(inp)
        expected += 10  # account for self.a[1] += 10 in forward.
        self.assertEqual(out, expected)

    def test_freeze_module_with_aliased_tensor_attr4(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1, 2, 3, 4, 5, 6])
                self.b = [self.a, torch.tensor([10])]

            def forward(self, x):
                self.b[0][0] += 10
                return self.a.sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        inp = torch.tensor([5])
        expected = m_s.forward(inp)
        m_s.a[0] -= 10
        with self.assertRaisesRegex(
            RuntimeError, "module contains attributes values that overlaps"
        ):
            m_f = torch._C._freeze_module(m_s._c)

    def test_freeze_module_with_overlapping_attrs(self):
        a = torch.tensor([1, 2, 3, 4, 5, 6])

        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = [a.view(3, 2), torch.tensor([10])]
                self.c = (20, a.view(2, 3))

            def forward(self, x):
                self.b[0][0] += 10
                return self.c[1].sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        inp = torch.tensor([5])
        expected = m_s.forward(inp)
        a[0] -= 10
        with self.assertRaisesRegex(
            RuntimeError, "module contains attributes values that overlaps"
        ):
            m_f = torch._C._freeze_module(m_s._c)

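    # The overlap errors above arise when two attributes alias the same
    # storage. A sketch of the usual workaround (ours, not an API guarantee;
    # the helper name is hypothetical): materialize independent storage with
    # .clone() so alias analysis no longer sees overlapping attribute values
    # and freezing succeeds.
    def _example_clone_breaks_overlap(self):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1, 2, 3, 4, 5, 6])
                # With a plain self.a.view(2, 3) here, freezing would raise
                # the overlap error; .clone() gives 'd' its own storage.
                self.d = (self.a.view(2, 3).clone(), 20)

            def forward(self, x):
                self.d[0][0] += 10
                return self.a.sum()

        m_f = torch.jit.freeze(torch.jit.script(M().eval()))
        # self.a is unaffected by the write to the cloned 'd'.
        self.assertEqual(m_f(torch.tensor([0])), torch.tensor(21))
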
    def test_freeze_module_with_aliased_attr(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = [1, 2, 3, 4, 5, 6]
                self.b = self.a
                self.c = (self.a, 10)

            def forward(self, x):
                self.b[1] += 10
                return str(self.a) + str(self.c)

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        # FIXME: It should be assertTrue. Scripting currently makes a copy when setting self.b (see #33034).
        self.assertFalse(m_f.hasattr("a"))
        self.assertFalse(m_f.hasattr("c"))
        inp = torch.tensor([5])
        out = m_f.forward(inp)
        expected = m_s.forward(inp)
        self.assertEqual(out, expected)

    # Check that attribute 'a' is preserved. Alias analysis detects that 'a' has
    # output writers. In this example, 'a' is not mutated, but we do not track
    # which sub-values of a composite ivalue are mutated.
    def test_freeze_module_with_aliased_attr2(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = [1, 2, 3, 4, 5, 6]
                self.b = ([11], [10])

            def forward(self, x):
                v = self.a
                self.b = (v, [12])
                v2 = self.b[1]
                v2.append(7)
                return str(v) + str(v2)

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        self.assertTrue(m_f.hasattr("a"))
        inp = torch.tensor([5])
        out = m_f.forward(inp)
        expected = m.forward(inp)
        self.assertEqual(out, expected)

    def test_freeze_module_with_aliased_attr3(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = [1, 2, 3, 4, 5, 6]
                self.b = ([11], [10])

            def forward(self, x):
                v = self.a
                v2 = (v, [12])
                v3 = v2[0]
                v3.append(7)
                return str(self.a)

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        self.assertTrue(m_f.hasattr("a"))
        inp = torch.tensor([5])
        out = m_f.forward(inp)
        expected = m.forward(inp)
        self.assertEqual(out, expected)

    def test_freeze_module_return_self(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.0, 2.0, 3.0])

            def forward(self, x):
                return self

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        with self.assertRaisesRegex(
            RuntimeError, "attempted to freeze a module that return itself"
        ):
            m_f = torch._C._freeze_module(m_s._c)

    def test_freeze_module_inlining(self):
        @torch.jit.script  # noqa: B903
        class Obj:  # noqa: B903
            def __init__(self, x: int, y: int):
                self.x = x
                self.y = y

        class Mod(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.obj = Obj(2, 3)

            def forward(self, i: int):
                print(self.obj)
                return i

        mod = torch.jit.freeze(torch.jit.script(Mod().eval()))
        obj = mod.graph.findNode("prim::Constant")
        self.assertTrue(torch._C._jit_object_is_non_holding(obj))

        buffer = io.BytesIO()
        torch.jit.save(mod, buffer)
        buffer.seek(0)

        loaded = torch.jit.load(buffer)
        obj = loaded.graph.findNode("prim::Constant")
        self.assertTrue(torch._C._jit_object_is_non_holding(obj))

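    # FileCheck, used throughout this file, is a convenient way to inspect
    # what freezing did to a graph. A minimal sketch (ours, not part of the
    # suite; the helper name is hypothetical): after folding, no attribute
    # reads should remain in the frozen forward graph.
    def _example_filecheck_frozen_graph(self):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.0])

            def forward(self, x):
                return x + self.a

        m_f = torch.jit.freeze(torch.jit.script(M().eval()))
        FileCheck().check_not("prim::GetAttr").run(m_f.graph)
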
    def test_freeze_module_return_sub_module(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, 1)

            def forward(self, x):
                return self.conv1

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        self.assertTrue(m_f.hasattr("conv1"))

    def test_freeze_module_no_forward(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = nn.Linear(10, 1)

            @torch.jit.export
            def foo(self, x):
                return self.lin(x)

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c, preservedAttrs=["foo"])
        input = torch.ones(10)
        self.assertEqual(m_s.foo(input), m_f.foo(input))

    def test_freeze_no_forward(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = nn.Linear(10, 1)

            @torch.jit.export
            def foo(self, x):
                return self.lin(x)

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        m_f = torch.jit.freeze(m_s, preserved_attrs=["foo"])
        input = torch.ones(10)
        self.assertEqual(m_s.foo(input), m_f.foo(input))

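    # Note the two equivalent spellings above: the private binding takes
    # preservedAttrs while the public torch.jit.freeze takes preserved_attrs.
    # A sketch (ours, not part of the suite; the helper name is hypothetical)
    # of the practical upshot: any attribute that a preserved, exported
    # method reads stays usable after freezing.
    def _example_preserved_method_keeps_attrs(self):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = nn.Linear(10, 1)

            def forward(self, x):
                return x

            @torch.jit.export
            def foo(self, x):
                return self.lin(x)

        m_s = torch.jit.script(M().eval())
        m_f = torch.jit.freeze(m_s, preserved_attrs=["foo"])
        self.assertEqual(m_f.foo(torch.ones(10)).shape, (1,))
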
    def test_freeze_module_in_training_mode(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, 1)
                self.conv2 = nn.Conv2d(32, 64, 3, 1)
                self.dropout1 = nn.Dropout2d(0.25)
                self.dropout2 = nn.Dropout2d(0.5)
                self.fc1 = nn.Linear(9216, 128)
                self.fc2 = nn.Linear(128, 10)

            def forward(self, x):
                x = self.conv1(x)
                x = nn.functional.relu(x)
                x = self.conv2(x)
                x = nn.functional.max_pool2d(x, 2)
                x = self.dropout1(x)
                x = torch.flatten(x, 1)
                x = self.fc1(x)
                x = nn.functional.relu(x)
                x = self.dropout2(x)
                x = self.fc2(x)
                output = nn.functional.log_softmax(x, dim=1)
                return output

        model = torch.jit.script(Net())
        model.train()
        mTrain_freezed = torch._C._freeze_module(model._c)
        # verify that mTrain_freezed looks exactly like:
        # module {
        #   attributes {
        #     conv1 = ...
        #     conv2 = ...
        #     dropout1 = ...
        #     dropout2 = ...
        #     fc1 = ...
        #     fc2 = ...
        #   }
        #   ...
        #   submodules {
        #     module conv1 {
        #       attributes {
        #          weight = ...
        #          bias = ...
        #       }
        #       ...
        #     }
        #     module conv2 {
        #       attributes {
        #          weight = ...
        #          bias = ...
        #       }
        #       ...
        #     }
        #     module dropout1 {
        #       attributes {
        #          training = ...
        #       }
        #       ...
        #     }
        #     module dropout2 {
        #       attributes {
        #          training = ...
        #       }
        #       ...
        #     }
        #     module fc1 {
        #       attributes {
        #          weight = ...
        #          bias = ...
        #       }
        #       ...
        #     }
        #     module fc2 {
        #       attributes {
        #          weight = ...
        #          bias = ...
        #       }
        #       ...
        #     }
        #   }
        # }
        self.assertFalse(mTrain_freezed.hasattr("training"))
        self.assertTrue(mTrain_freezed.hasattr("conv1"))
        self.assertFalse(mTrain_freezed.conv1.hasattr("training"))
        self.assertTrue(mTrain_freezed.conv1.hasattr("weight"))
        self.assertTrue(mTrain_freezed.conv1.hasattr("bias"))
        self.assertTrue(mTrain_freezed.hasattr("conv2"))
        self.assertFalse(mTrain_freezed.conv2.hasattr("training"))
        self.assertTrue(mTrain_freezed.conv2.hasattr("weight"))
        self.assertTrue(mTrain_freezed.conv2.hasattr("bias"))
        self.assertTrue(mTrain_freezed.hasattr("dropout1"))
        self.assertTrue(mTrain_freezed.dropout1.hasattr("training"))
        self.assertTrue(mTrain_freezed.hasattr("dropout2"))
        self.assertTrue(mTrain_freezed.dropout2.hasattr("training"))
        self.assertTrue(mTrain_freezed.hasattr("fc1"))
        self.assertTrue(mTrain_freezed.fc1.hasattr("weight"))
        self.assertTrue(mTrain_freezed.fc1.hasattr("bias"))
        self.assertTrue(mTrain_freezed.hasattr("fc2"))
        self.assertTrue(mTrain_freezed.fc2.hasattr("weight"))
        self.assertTrue(mTrain_freezed.fc2.hasattr("bias"))
        model.eval()
        mEval_freezed = torch._C._freeze_module(model._c)
        self.assertFalse(mEval_freezed.hasattr("conv1"))
        self.assertFalse(mEval_freezed.hasattr("conv2"))
        self.assertFalse(mEval_freezed.hasattr("dropout1"))
        self.assertFalse(mEval_freezed.hasattr("training"))
        self.assertFalse(mEval_freezed.hasattr("fc1"))
        self.assertFalse(mEval_freezed.hasattr("dropout2"))
        self.assertFalse(mEval_freezed.hasattr("fc2"))
        with self.assertRaisesRegex(
            AttributeError, "does not have a field with name 'state_dict'"
        ):
            print(mEval_freezed.state_dict())
        buffer = io.BytesIO()
        torch.jit.save(mEval_freezed, buffer)
        buffer.seek(0)
        m = torch.jit.load(buffer)
        FileCheck().check_not("GetAttr[name=").run(m._c._get_method("forward").graph)
        m2 = torch._C._freeze_module(model._c, preserveParameters=True)
        self.assertTrue(m2.hasattr("conv1"))
        self.assertTrue(m2.hasattr("conv2"))
        self.assertFalse(m2.hasattr("dropout1"))
        self.assertFalse(m2.hasattr("training"))
        self.assertTrue(m2.hasattr("fc1"))
        self.assertFalse(m2.hasattr("dropout2"))
        self.assertTrue(m2.hasattr("fc2"))

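    # The test above reaches for the private binding because the public
    # wrapper refuses modules left in training mode. A sketch of that guard
    # (ours, not part of the suite; the exact error wording is an assumption,
    # hence the loose regex, and the helper name is hypothetical):
    def _example_freeze_requires_eval(self):
        model = torch.jit.script(nn.Linear(2, 2))
        model.train()
        with self.assertRaisesRegex(RuntimeError, "eval mode"):
            torch.jit.freeze(model)
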
    def test_freeze_module_detach_gradient(self):
        mod = nn.Conv2d(8, 3, 4, 2, 1)
        self.assertTrue(mod.weight.requires_grad)
        smod = torch.jit.script(mod)
        smod.eval()
        fmod = torch._C._freeze_module(smod._c)
        self.assertTrue(mod.weight.requires_grad)
        self.assertTrue(smod.weight.requires_grad)
        self.assertFalse(fmod.hasattr("weight"))
        inp = torch.ones(1, 8, 32, 32)
        out1 = fmod.forward(inp)
        # FIXME: the frozen module can still be mutated from outside (via the original module).
        with torch.no_grad():
            smod.weight[0, 0, 0, 0] += 100.0
        out2 = fmod.forward(inp)
        out3 = smod(inp)
        self.assertNotEqual(out1, out2)
        self.assertEqual(out2, out3)

    def test_freeze_module_with_user_preserved_attr(self):
        class Module(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = torch.tensor([2.2])

            def forward(self, x):
                return self.a + self.b

        m = torch.jit.script(Module())
        m.eval()
        fm = torch._C._freeze_module(m._c, ["a"])
        # Attribute "a" is preserved
        self.assertTrue(fm.hasattr("a"))
        self.assertFalse(fm.hasattr("b"))

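    # A small usage sketch (ours, not part of the suite; the helper name is
    # hypothetical) of what preservation buys you: a preserved tensor
    # attribute stays addressable and mutable on the frozen module, and
    # forward keeps reading it instead of a folded constant (compare the
    # submodule variant below, which mutates a preserved attribute the same way).
    def _example_mutate_preserved_attr(self):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.scale = torch.tensor([2.0])

            def forward(self, x):
                return x * self.scale

        m_s = torch.jit.script(M().eval())
        m_f = torch.jit.freeze(m_s, preserved_attrs=["scale"])
        self.assertEqual(m_f(torch.ones(1)), torch.tensor([2.0]))
        m_f.scale[0] = 3.0  # preserved attrs remain live
        self.assertEqual(m_f(torch.ones(1)), torch.tensor([3.0]))
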
    def test_freeze_module_with_user_preserved_method(self):
        class Module(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = torch.tensor([2.2])

            def forward(self, x):
                return self.a + self.b

            @torch.jit.export
            def modify_a(self, x):
                self.a[0] += 10
                return self.b

            @torch.jit.export
            def modify_b(self, x):
                self.b[0] += 20
                return self.a

        m = torch.jit.script(Module())
        m.eval()
        fm = torch._C._freeze_module(m._c, ["modify_a"])
        # Both attribute "a" and method "modify_a" are preserved
        self.assertTrue(fm.hasattr("a"))
        self.assertFalse(fm.hasattr("b"))
        input = torch.randn(2, 2)
        expected = m.forward(input)
        out = fm.forward(input)
        self.assertEqual(out, expected)

    def test_freeze_module_with_user_preserved_method2(self):
        class Module(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = torch.tensor([2.2])

            def forward(self, x):
                self.b += 10
                return self.a + self.b

            @torch.jit.export
            def modify_a(self, x):
                self.a[0] += 10
                return self.b + self.a

        m = torch.jit.script(Module())
        m.eval()
        fm = torch._C._freeze_module(m._c, ["modify_a"])
        FileCheck().check('prim::GetAttr[name="a"]').run(fm.forward.graph)
        FileCheck().check('prim::GetAttr[name="b"]').run(fm.modify_a.graph)

    def test_freeze_module_with_user_preserved_attribute_on_submodule(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 1
                self.b = 2

            def forward(self):
                return self.a + self.b

        class Module(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = SubModule()
                self.sub2 = SubModule()

            def forward(self):
                return self.sub1() + self.sub2()

        m = torch.jit.script(Module())
        m.eval()
        m = torch.jit.freeze(m, preserved_attrs=["sub1.a", "sub2.a"])
        fm = m._c

        self.assertTrue(fm.hasattr("sub1"))
        self.assertTrue(fm.sub1.hasattr("a"))
        self.assertFalse(fm.sub1.hasattr("b"))
        self.assertTrue(fm.hasattr("sub2"))
        self.assertTrue(fm.sub2.hasattr("a"))
        self.assertFalse(fm.sub2.hasattr("b"))
        self.assertEqual(m(), 6)
        m.sub1.a += 1
        self.assertEqual(m(), 7)

    def test_freeze_module_with_user_preserved_attribute_on_unused_submodule(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 1
                self.b = 2

            def forward(self):
                return self.a + self.b

            @torch.jit.export
            def method_a(self):
                return 42

        class Module(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule()

            def forward(self):
                return 1

        m = torch.jit.script(Module())
        m.eval()
        fm = torch.jit.freeze(m, preserved_attrs=["sub.a", "sub.method_a"])._c

        self.assertTrue(fm.hasattr("sub"))
        self.assertTrue(fm.sub.hasattr("a"))
        self.assertFalse(fm.sub.hasattr("b"))
        self.assertTrue(fm.sub._has_method("method_a"))

    def test_freeze_module_with_user_preserved_method_on_submodule(self):
        class SubModule(nn.Module):
            def forward(self, x):
                return self.method_a(x) + self.method_b(x)

            def method_a(self, x):
                return x * x

            def method_b(self, x):
                return x + x

        class Module(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule()

            def forward(self, x):
                return self.sub(x)

        m = torch.jit.script(Module())
        m.eval()
        fm = torch.jit.freeze(m, preserved_attrs=["sub.method_a"])._c

        self.assertTrue(fm.hasattr("sub"))
        self.assertTrue(fm.sub._has_method("method_a"))
        self.assertFalse(fm.sub._has_method("method_b"))

    @skipIfNoFBGEMM
    def test_module_with_shared_type_instances(self):
        class Child(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 1, 1).to(dtype=torch.float32)

            def forward(self, x):
                x = self.conv1(x)
                return x

        class Parent(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1).to(dtype=torch.float32)
                self.child = Child()
                self.child2 = Child()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.child(x)
                x = self.child2(x)
                x = self.dequant(x)
                return x

        def _static_quant(model):
            qModel = torch.ao.quantization.QuantWrapper(model)
            qModel.qconfig = torch.ao.quantization.default_qconfig
            torch.ao.quantization.prepare(qModel, inplace=True)
            qModel(torch.rand(4, 1, 4, 4, dtype=torch.float32))
            torch.ao.quantization.convert(qModel, inplace=True)
            return model

        with override_quantized_engine("fbgemm"):
            data = torch.randn(4, 1, 4, 4, dtype=torch.float32)
            m = Parent().to(torch.float32)
            m = _static_quant(m)
            m = torch.jit.script(m)
            m.eval()
            torch._C._jit_pass_inline(m.graph)
            m_frozen = wrap_cpp_module(torch._C._freeze_module(m._c))
            # An earlier bug caused _packed_params to be set to False.
1488*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_not("_packed_params = False").run(
1489*da0073e9SAndroid Build Coastguard Worker                m_frozen._c.dump_to_str(True, True, False)
1490*da0073e9SAndroid Build Coastguard Worker            )
1491*da0073e9SAndroid Build Coastguard Worker
1492*da0073e9SAndroid Build Coastguard Worker            m_res = m(data)
            # Running the frozen module used to segfault.
1494*da0073e9SAndroid Build Coastguard Worker            m_frozen_res = m_frozen(data)
1495*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m_res, m_frozen_res)
1496*da0073e9SAndroid Build Coastguard Worker
1497*da0073e9SAndroid Build Coastguard Worker    def test_module_getattr_indirection(self):
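        # The attribute access goes through a local variable (mod = self.mod1),
        # so freezing cannot statically resolve which attribute is read and
        # must still preserve eager behavior for both branches.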
1498*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
1499*da0073e9SAndroid Build Coastguard Worker        class ValHolder:
1500*da0073e9SAndroid Build Coastguard Worker            def __init__(self, val: int):
1501*da0073e9SAndroid Build Coastguard Worker                self.val: int = val
1502*da0073e9SAndroid Build Coastguard Worker
1503*da0073e9SAndroid Build Coastguard Worker        class Mod(nn.Module):
1504*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1505*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1506*da0073e9SAndroid Build Coastguard Worker                self.mod1 = ValHolder(1)
1507*da0073e9SAndroid Build Coastguard Worker                self.mod2 = ValHolder(2)
1508*da0073e9SAndroid Build Coastguard Worker
1509*da0073e9SAndroid Build Coastguard Worker            def forward(self, cond: bool):
1510*da0073e9SAndroid Build Coastguard Worker                if cond:
1511*da0073e9SAndroid Build Coastguard Worker                    mod = self.mod1
1512*da0073e9SAndroid Build Coastguard Worker                else:
1513*da0073e9SAndroid Build Coastguard Worker                    mod = self.mod2
1514*da0073e9SAndroid Build Coastguard Worker                return mod.val
1515*da0073e9SAndroid Build Coastguard Worker
1516*da0073e9SAndroid Build Coastguard Worker        mod = Mod()
1517*da0073e9SAndroid Build Coastguard Worker        mod.eval()
1518*da0073e9SAndroid Build Coastguard Worker        frozen_mod = torch.jit.freeze(torch.jit.script(mod))
1519*da0073e9SAndroid Build Coastguard Worker        mod_eager = Mod()
1520*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(mod_eager(True), frozen_mod(True))
1521*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(mod_eager(False), frozen_mod(False))
1522*da0073e9SAndroid Build Coastguard Worker
1523*da0073e9SAndroid Build Coastguard Worker    def test_freeze_module_with_non_static_module_container_index(self):
1524*da0073e9SAndroid Build Coastguard Worker        """
1525*da0073e9SAndroid Build Coastguard Worker        Test that Modules containing non-static ModuleDict or ModuleList
1526*da0073e9SAndroid Build Coastguard Worker        indexing cannot be frozen.
1527*da0073e9SAndroid Build Coastguard Worker        """
1528*da0073e9SAndroid Build Coastguard Worker
1529*da0073e9SAndroid Build Coastguard Worker        @torch.jit.interface
1530*da0073e9SAndroid Build Coastguard Worker        class ModuleInterface(torch.nn.Module):
1531*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp: Any) -> Any:
1532*da0073e9SAndroid Build Coastguard Worker                pass
1533*da0073e9SAndroid Build Coastguard Worker
1534*da0073e9SAndroid Build Coastguard Worker        class ImplementsInterface(torch.nn.Module):
1535*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp: Any) -> Any:
1536*da0073e9SAndroid Build Coastguard Worker                if isinstance(inp, torch.Tensor):
1537*da0073e9SAndroid Build Coastguard Worker                    return torch.max(inp, dim=0)
1538*da0073e9SAndroid Build Coastguard Worker
1539*da0073e9SAndroid Build Coastguard Worker                return inp
1540*da0073e9SAndroid Build Coastguard Worker
1541*da0073e9SAndroid Build Coastguard Worker        class ModWithDict(torch.nn.Module):
1542*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1543*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1544*da0073e9SAndroid Build Coastguard Worker                self.d = torch.nn.ModuleDict({"module": ImplementsInterface()})
1545*da0073e9SAndroid Build Coastguard Worker
1546*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor, key: str) -> Any:
1547*da0073e9SAndroid Build Coastguard Worker                value: ModuleInterface = self.d[key]
1548*da0073e9SAndroid Build Coastguard Worker                return value.forward(x)
1549*da0073e9SAndroid Build Coastguard Worker
1550*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.script(ModWithDict())
1551*da0073e9SAndroid Build Coastguard Worker        m.eval()
1552*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1553*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1554*da0073e9SAndroid Build Coastguard Worker            "Freezing modules containing prim::ModuleContainerIndex is not supported",
1555*da0073e9SAndroid Build Coastguard Worker        ):
            torch._C._freeze_module(m._c)
1557*da0073e9SAndroid Build Coastguard Worker
1558*da0073e9SAndroid Build Coastguard Worker        class ModWithList(torch.nn.Module):
1559*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1560*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1561*da0073e9SAndroid Build Coastguard Worker                self.l = torch.nn.ModuleList([ImplementsInterface()])
1562*da0073e9SAndroid Build Coastguard Worker
1563*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor, idx: int) -> Any:
1564*da0073e9SAndroid Build Coastguard Worker                value: ModuleInterface = self.l[idx]
1565*da0073e9SAndroid Build Coastguard Worker                return value.forward(x)
1566*da0073e9SAndroid Build Coastguard Worker
1567*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.script(ModWithList())
1568*da0073e9SAndroid Build Coastguard Worker        m.eval()
1569*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1570*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1571*da0073e9SAndroid Build Coastguard Worker            "Freezing modules containing prim::ModuleContainerIndex is not supported",
1572*da0073e9SAndroid Build Coastguard Worker        ):
            torch._C._freeze_module(m._c)
1574*da0073e9SAndroid Build Coastguard Worker
1575*da0073e9SAndroid Build Coastguard Worker    def test_freeze_with_interface_mutable(self):
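        # The interface-typed submodule mutates its own attribute in forward;
        # freezing must keep self.impl.sum as a mutable attribute instead of
        # folding it into a constant.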
1576*da0073e9SAndroid Build Coastguard Worker        @torch.jit.interface
1577*da0073e9SAndroid Build Coastguard Worker        class ModuleInterface(torch.nn.Module):
1578*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp: torch.Tensor) -> torch.Tensor:
1579*da0073e9SAndroid Build Coastguard Worker                pass
1580*da0073e9SAndroid Build Coastguard Worker
1581*da0073e9SAndroid Build Coastguard Worker        class ImplementsInterface(torch.nn.Module):
1582*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1583*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1584*da0073e9SAndroid Build Coastguard Worker                self.sum = torch.zeros((2, 2))
1585*da0073e9SAndroid Build Coastguard Worker
1586*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp: torch.Tensor) -> torch.Tensor:
1587*da0073e9SAndroid Build Coastguard Worker                self.sum += inp.relu()
1588*da0073e9SAndroid Build Coastguard Worker                return self.sum
1589*da0073e9SAndroid Build Coastguard Worker
1590*da0073e9SAndroid Build Coastguard Worker        class WrapperModule(torch.nn.Module):
1591*da0073e9SAndroid Build Coastguard Worker            impl: ModuleInterface
1592*da0073e9SAndroid Build Coastguard Worker
1593*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1594*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1595*da0073e9SAndroid Build Coastguard Worker                self.impl = ImplementsInterface()
1596*da0073e9SAndroid Build Coastguard Worker
1597*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
1598*da0073e9SAndroid Build Coastguard Worker                return self.impl.forward(x)
1599*da0073e9SAndroid Build Coastguard Worker
1600*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.script(WrapperModule())
1601*da0073e9SAndroid Build Coastguard Worker        m.eval()
1602*da0073e9SAndroid Build Coastguard Worker        m_frozen = torch.jit.freeze(m)
1603*da0073e9SAndroid Build Coastguard Worker
1604*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((2, 2))
1605*da0073e9SAndroid Build Coastguard Worker
1606*da0073e9SAndroid Build Coastguard Worker        m_frozen(x)
1607*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m_frozen.impl.sum, x.relu())
1608*da0073e9SAndroid Build Coastguard Worker
1609*da0073e9SAndroid Build Coastguard Worker    def test_freeze_with_swapping_interfaces(self):
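        # forward reassigns the interface-typed attribute self.impl; freezing
        # is expected to reject this (SetAttr on an interface type) rather
        # than silently specializing to one implementation.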
1610*da0073e9SAndroid Build Coastguard Worker        @torch.jit.interface
1611*da0073e9SAndroid Build Coastguard Worker        class ModuleInterface(torch.nn.Module):
1612*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp: torch.Tensor) -> torch.Tensor:
1613*da0073e9SAndroid Build Coastguard Worker                pass
1614*da0073e9SAndroid Build Coastguard Worker
1615*da0073e9SAndroid Build Coastguard Worker        class Implementation1(torch.nn.Module):
1616*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp: torch.Tensor) -> torch.Tensor:
1617*da0073e9SAndroid Build Coastguard Worker                return inp.relu()
1618*da0073e9SAndroid Build Coastguard Worker
1619*da0073e9SAndroid Build Coastguard Worker        class Implementation2(torch.nn.Module):
1620*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp: torch.Tensor) -> torch.Tensor:
1621*da0073e9SAndroid Build Coastguard Worker                return inp.sin()
1622*da0073e9SAndroid Build Coastguard Worker
1623*da0073e9SAndroid Build Coastguard Worker        class WrapperModule(torch.nn.Module):
1624*da0073e9SAndroid Build Coastguard Worker            impl: ModuleInterface
1625*da0073e9SAndroid Build Coastguard Worker
1626*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1627*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1628*da0073e9SAndroid Build Coastguard Worker                self.option1 = Implementation1()
1629*da0073e9SAndroid Build Coastguard Worker                self.option2 = Implementation2()
1630*da0073e9SAndroid Build Coastguard Worker                self.impl = self.option1
1631*da0073e9SAndroid Build Coastguard Worker                self.idx = 0
1632*da0073e9SAndroid Build Coastguard Worker
1633*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
1634*da0073e9SAndroid Build Coastguard Worker                self.idx += 1
1635*da0073e9SAndroid Build Coastguard Worker                if self.idx % 2 == 1:
1636*da0073e9SAndroid Build Coastguard Worker                    self.impl = self.option1
1637*da0073e9SAndroid Build Coastguard Worker                else:
1638*da0073e9SAndroid Build Coastguard Worker                    self.impl = self.option2
1639*da0073e9SAndroid Build Coastguard Worker                return self.impl(x)
1640*da0073e9SAndroid Build Coastguard Worker
1641*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.script(WrapperModule())
1642*da0073e9SAndroid Build Coastguard Worker        m.eval()
1643*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1644*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Freezing does not support SetAttr on an interface type"
1645*da0073e9SAndroid Build Coastguard Worker        ):
1646*da0073e9SAndroid Build Coastguard Worker            m_frozen = torch.jit.freeze(m)
1647*da0073e9SAndroid Build Coastguard Worker
1648*da0073e9SAndroid Build Coastguard Worker    def test_freeze_recursive_interfaces(self):
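        # An interface implementation that itself holds an interface-typed
        # submodule: freezing must resolve and inline both levels.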
1649*da0073e9SAndroid Build Coastguard Worker        @torch.jit.interface
1650*da0073e9SAndroid Build Coastguard Worker        class InnerInterface(torch.nn.Module):
1651*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp: torch.Tensor) -> torch.Tensor:
1652*da0073e9SAndroid Build Coastguard Worker                pass
1653*da0073e9SAndroid Build Coastguard Worker
1654*da0073e9SAndroid Build Coastguard Worker        @torch.jit.interface
1655*da0073e9SAndroid Build Coastguard Worker        class OuterInterface(torch.nn.Module):
1656*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp: torch.Tensor) -> torch.Tensor:
1657*da0073e9SAndroid Build Coastguard Worker                pass
1658*da0073e9SAndroid Build Coastguard Worker
1659*da0073e9SAndroid Build Coastguard Worker        class InnerImpl(torch.nn.Module):
1660*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1661*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1662*da0073e9SAndroid Build Coastguard Worker                self.x = torch.ones((2, 2))
1663*da0073e9SAndroid Build Coastguard Worker
1664*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp):
1665*da0073e9SAndroid Build Coastguard Worker                return inp.cos() * self.x
1666*da0073e9SAndroid Build Coastguard Worker
1667*da0073e9SAndroid Build Coastguard Worker        class OuterImpl(torch.nn.Module):
1668*da0073e9SAndroid Build Coastguard Worker            inner_impl: InnerInterface
1669*da0073e9SAndroid Build Coastguard Worker
1670*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1671*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1672*da0073e9SAndroid Build Coastguard Worker                self.inner_impl = InnerImpl()
1673*da0073e9SAndroid Build Coastguard Worker
1674*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp):
1675*da0073e9SAndroid Build Coastguard Worker                return inp.relu() + self.inner_impl(inp.sin())
1676*da0073e9SAndroid Build Coastguard Worker
1677*da0073e9SAndroid Build Coastguard Worker        class WrapperModule(torch.nn.Module):
1678*da0073e9SAndroid Build Coastguard Worker            outer_impl: OuterInterface
1679*da0073e9SAndroid Build Coastguard Worker
1680*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1681*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1682*da0073e9SAndroid Build Coastguard Worker                self.outer_impl = OuterImpl()
1683*da0073e9SAndroid Build Coastguard Worker
1684*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp):
1685*da0073e9SAndroid Build Coastguard Worker                return self.outer_impl(inp) + inp
1686*da0073e9SAndroid Build Coastguard Worker
1687*da0073e9SAndroid Build Coastguard Worker        m = WrapperModule()
1688*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((2, 2))
1689*da0073e9SAndroid Build Coastguard Worker        expected = m(x)
1690*da0073e9SAndroid Build Coastguard Worker
1691*da0073e9SAndroid Build Coastguard Worker        m_s = torch.jit.script(m)
1692*da0073e9SAndroid Build Coastguard Worker        m_s.eval()
1693*da0073e9SAndroid Build Coastguard Worker        m_s = torch.jit.freeze(m_s)
1694*da0073e9SAndroid Build Coastguard Worker        actual = m_s(x)
1695*da0073e9SAndroid Build Coastguard Worker
1696*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
1697*da0073e9SAndroid Build Coastguard Worker
1698*da0073e9SAndroid Build Coastguard Worker    def test_freeze_recursive_interfaces_with_reassignment(self):
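        # Like the recursive case above, but with a SetAttr on the inner
        # interface one level down; freezing must still detect and reject it.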
1699*da0073e9SAndroid Build Coastguard Worker        @torch.jit.interface
1700*da0073e9SAndroid Build Coastguard Worker        class InnerInterface(torch.nn.Module):
1701*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp: torch.Tensor) -> torch.Tensor:
1702*da0073e9SAndroid Build Coastguard Worker                pass
1703*da0073e9SAndroid Build Coastguard Worker
1704*da0073e9SAndroid Build Coastguard Worker        @torch.jit.interface
1705*da0073e9SAndroid Build Coastguard Worker        class OuterInterface(torch.nn.Module):
1706*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp: torch.Tensor) -> torch.Tensor:
1707*da0073e9SAndroid Build Coastguard Worker                pass
1708*da0073e9SAndroid Build Coastguard Worker
1709*da0073e9SAndroid Build Coastguard Worker        class InnerImpl1(torch.nn.Module):
1710*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1711*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1712*da0073e9SAndroid Build Coastguard Worker                self.x = torch.ones((2, 2))
1713*da0073e9SAndroid Build Coastguard Worker
1714*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp):
1715*da0073e9SAndroid Build Coastguard Worker                return inp.cos() * self.x
1716*da0073e9SAndroid Build Coastguard Worker
1717*da0073e9SAndroid Build Coastguard Worker        class InnerImpl2(torch.nn.Module):
1718*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1719*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1720*da0073e9SAndroid Build Coastguard Worker                self.x = torch.ones((2, 2)) * 2
1721*da0073e9SAndroid Build Coastguard Worker
1722*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp):
1723*da0073e9SAndroid Build Coastguard Worker                return inp.sin() / self.x
1724*da0073e9SAndroid Build Coastguard Worker
1725*da0073e9SAndroid Build Coastguard Worker        class OuterImpl(torch.nn.Module):
1726*da0073e9SAndroid Build Coastguard Worker            inner_impl: InnerInterface
1727*da0073e9SAndroid Build Coastguard Worker
1728*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1729*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1730*da0073e9SAndroid Build Coastguard Worker                self.inner_impl = InnerImpl1()
1731*da0073e9SAndroid Build Coastguard Worker                self.impl1 = InnerImpl1()
                self.impl2 = InnerImpl2()
1733*da0073e9SAndroid Build Coastguard Worker                self.idx = 0
1734*da0073e9SAndroid Build Coastguard Worker
1735*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp):
1736*da0073e9SAndroid Build Coastguard Worker                self.idx += 1
1737*da0073e9SAndroid Build Coastguard Worker                if self.idx % 2 == 0:
1738*da0073e9SAndroid Build Coastguard Worker                    self.inner_impl = self.impl1
1739*da0073e9SAndroid Build Coastguard Worker                else:
1740*da0073e9SAndroid Build Coastguard Worker                    self.inner_impl = self.impl2
1741*da0073e9SAndroid Build Coastguard Worker                return inp.relu() + self.inner_impl(inp.sin())
1742*da0073e9SAndroid Build Coastguard Worker
1743*da0073e9SAndroid Build Coastguard Worker        class WrapperModule(torch.nn.Module):
1744*da0073e9SAndroid Build Coastguard Worker            outer_impl: OuterInterface
1745*da0073e9SAndroid Build Coastguard Worker
1746*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1747*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1748*da0073e9SAndroid Build Coastguard Worker                self.outer_impl = OuterImpl()
1749*da0073e9SAndroid Build Coastguard Worker
1750*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp):
1751*da0073e9SAndroid Build Coastguard Worker                return self.outer_impl(inp) + inp
1752*da0073e9SAndroid Build Coastguard Worker
1753*da0073e9SAndroid Build Coastguard Worker        m = WrapperModule()
1754*da0073e9SAndroid Build Coastguard Worker
1755*da0073e9SAndroid Build Coastguard Worker        m_s = torch.jit.script(m)
1756*da0073e9SAndroid Build Coastguard Worker        m_s.eval()
1757*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1758*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Freezing does not support SetAttr on an interface type"
1759*da0073e9SAndroid Build Coastguard Worker        ):
1760*da0073e9SAndroid Build Coastguard Worker            m_s = torch.jit.freeze(m_s)
1761*da0073e9SAndroid Build Coastguard Worker
1762*da0073e9SAndroid Build Coastguard Worker    def test_freeze_interface_swapping_two_methods(self):
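        # The interface swap may live either in forward or in a preserved
        # exported method; both variants below must be rejected.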
1763*da0073e9SAndroid Build Coastguard Worker        @torch.jit.interface
1764*da0073e9SAndroid Build Coastguard Worker        class MyInterface(torch.nn.Module):
1765*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp: torch.Tensor) -> torch.Tensor:
1766*da0073e9SAndroid Build Coastguard Worker                pass
1767*da0073e9SAndroid Build Coastguard Worker
1768*da0073e9SAndroid Build Coastguard Worker        class Impl1(torch.nn.Module):
1769*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp):
1770*da0073e9SAndroid Build Coastguard Worker                return inp.cos()
1771*da0073e9SAndroid Build Coastguard Worker
1772*da0073e9SAndroid Build Coastguard Worker        class Impl2(torch.nn.Module):
1773*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp):
1774*da0073e9SAndroid Build Coastguard Worker                return inp.sin()
1775*da0073e9SAndroid Build Coastguard Worker
1776*da0073e9SAndroid Build Coastguard Worker        class WrapperModule1(torch.nn.Module):
1777*da0073e9SAndroid Build Coastguard Worker            interface_impl: MyInterface
1778*da0073e9SAndroid Build Coastguard Worker
1779*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1780*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1781*da0073e9SAndroid Build Coastguard Worker                self.interface_impl = Impl1()
1782*da0073e9SAndroid Build Coastguard Worker                self.impl1 = Impl1()
1783*da0073e9SAndroid Build Coastguard Worker                self.impl2 = Impl2()
1784*da0073e9SAndroid Build Coastguard Worker                self.idx = 0
1785*da0073e9SAndroid Build Coastguard Worker
1786*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1787*da0073e9SAndroid Build Coastguard Worker                return self.interface_impl(x)
1788*da0073e9SAndroid Build Coastguard Worker
1789*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
1790*da0073e9SAndroid Build Coastguard Worker            def other_method(self, x):
1791*da0073e9SAndroid Build Coastguard Worker                self.idx += 1
1792*da0073e9SAndroid Build Coastguard Worker                if self.idx % 2 == 0:
1793*da0073e9SAndroid Build Coastguard Worker                    self.interface_impl = self.impl1
1794*da0073e9SAndroid Build Coastguard Worker                else:
1795*da0073e9SAndroid Build Coastguard Worker                    self.interface_impl = self.impl2
1796*da0073e9SAndroid Build Coastguard Worker                return self.interface_impl(x)
1797*da0073e9SAndroid Build Coastguard Worker
1798*da0073e9SAndroid Build Coastguard Worker        class WrapperModule2(torch.nn.Module):
1799*da0073e9SAndroid Build Coastguard Worker            interface_impl: MyInterface
1800*da0073e9SAndroid Build Coastguard Worker
1801*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1802*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1803*da0073e9SAndroid Build Coastguard Worker                self.interface_impl = Impl1()
1804*da0073e9SAndroid Build Coastguard Worker                self.impl1 = Impl1()
1805*da0073e9SAndroid Build Coastguard Worker                self.impl2 = Impl2()
1806*da0073e9SAndroid Build Coastguard Worker                self.idx = 0
1807*da0073e9SAndroid Build Coastguard Worker
1808*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1809*da0073e9SAndroid Build Coastguard Worker                self.idx += 1
1810*da0073e9SAndroid Build Coastguard Worker                if self.idx % 2 == 0:
1811*da0073e9SAndroid Build Coastguard Worker                    self.interface_impl = self.impl1
1812*da0073e9SAndroid Build Coastguard Worker                else:
1813*da0073e9SAndroid Build Coastguard Worker                    self.interface_impl = self.impl2
1814*da0073e9SAndroid Build Coastguard Worker                return self.interface_impl(x)
1815*da0073e9SAndroid Build Coastguard Worker
1816*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
1817*da0073e9SAndroid Build Coastguard Worker            def other_method(self, x):
1818*da0073e9SAndroid Build Coastguard Worker                return self.interface_impl(x)
1819*da0073e9SAndroid Build Coastguard Worker
1820*da0073e9SAndroid Build Coastguard Worker        m1 = torch.jit.script(WrapperModule1())
1821*da0073e9SAndroid Build Coastguard Worker        m2 = torch.jit.script(WrapperModule2())
1822*da0073e9SAndroid Build Coastguard Worker
1823*da0073e9SAndroid Build Coastguard Worker        m1.eval()
1824*da0073e9SAndroid Build Coastguard Worker        m2.eval()
1825*da0073e9SAndroid Build Coastguard Worker
1826*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1827*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Freezing does not support SetAttr on an interface type"
1828*da0073e9SAndroid Build Coastguard Worker        ):
1829*da0073e9SAndroid Build Coastguard Worker            torch.jit.freeze(m1, preserved_attrs=["other_method"])
1830*da0073e9SAndroid Build Coastguard Worker
1831*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1832*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Freezing does not support SetAttr on an interface type"
1833*da0073e9SAndroid Build Coastguard Worker        ):
1834*da0073e9SAndroid Build Coastguard Worker            torch.jit.freeze(m2, preserved_attrs=["other_method"])
1835*da0073e9SAndroid Build Coastguard Worker
1836*da0073e9SAndroid Build Coastguard Worker    def test_freeze_recursive_interfaces_same_name(self):
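        # The attribute name "impl" is reused for interfaces at two nesting
        # levels (one reached via a helper method); freezing must not
        # conflate the two.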
1837*da0073e9SAndroid Build Coastguard Worker        @torch.jit.interface
1838*da0073e9SAndroid Build Coastguard Worker        class InnerInterface(torch.nn.Module):
1839*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp: torch.Tensor) -> torch.Tensor:
1840*da0073e9SAndroid Build Coastguard Worker                pass
1841*da0073e9SAndroid Build Coastguard Worker
1842*da0073e9SAndroid Build Coastguard Worker        @torch.jit.interface
1843*da0073e9SAndroid Build Coastguard Worker        class OuterInterface(torch.nn.Module):
1844*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp: torch.Tensor) -> torch.Tensor:
1845*da0073e9SAndroid Build Coastguard Worker                pass
1846*da0073e9SAndroid Build Coastguard Worker
1847*da0073e9SAndroid Build Coastguard Worker        class InnerImpl(torch.nn.Module):
1848*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1849*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1850*da0073e9SAndroid Build Coastguard Worker                self.x = torch.ones((2, 2))
1851*da0073e9SAndroid Build Coastguard Worker
1852*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp):
1853*da0073e9SAndroid Build Coastguard Worker                return inp.cos() * self.x
1854*da0073e9SAndroid Build Coastguard Worker
1855*da0073e9SAndroid Build Coastguard Worker        class OuterImpl(torch.nn.Module):
1856*da0073e9SAndroid Build Coastguard Worker            impl: InnerInterface
1857*da0073e9SAndroid Build Coastguard Worker
1858*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1859*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1860*da0073e9SAndroid Build Coastguard Worker                self.impl = InnerImpl()
1861*da0073e9SAndroid Build Coastguard Worker                self.x = torch.ones((2, 2)) * 5
1862*da0073e9SAndroid Build Coastguard Worker
1863*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp):
1864*da0073e9SAndroid Build Coastguard Worker                return self.other_method(inp)
1865*da0073e9SAndroid Build Coastguard Worker
1866*da0073e9SAndroid Build Coastguard Worker            def other_method(self, inp):
1867*da0073e9SAndroid Build Coastguard Worker                return inp.relu() + self.impl(inp.sin()) + self.x
1868*da0073e9SAndroid Build Coastguard Worker
1869*da0073e9SAndroid Build Coastguard Worker        class WrapperModule(torch.nn.Module):
1870*da0073e9SAndroid Build Coastguard Worker            impl: OuterInterface
1871*da0073e9SAndroid Build Coastguard Worker
1872*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1873*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1874*da0073e9SAndroid Build Coastguard Worker                self.impl = OuterImpl()
1875*da0073e9SAndroid Build Coastguard Worker
1876*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp):
1877*da0073e9SAndroid Build Coastguard Worker                return self.impl(inp) + inp
1878*da0073e9SAndroid Build Coastguard Worker
1879*da0073e9SAndroid Build Coastguard Worker        m = WrapperModule()
1880*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((2, 2))
1881*da0073e9SAndroid Build Coastguard Worker        expected = m(x)
1882*da0073e9SAndroid Build Coastguard Worker
1883*da0073e9SAndroid Build Coastguard Worker        m_s = torch.jit.script(m)
1884*da0073e9SAndroid Build Coastguard Worker        m_s.eval()
1885*da0073e9SAndroid Build Coastguard Worker        m_s = torch.jit.freeze(m_s)
1886*da0073e9SAndroid Build Coastguard Worker        actual = m_s(x)
1887*da0073e9SAndroid Build Coastguard Worker
1888*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
1889*da0073e9SAndroid Build Coastguard Worker
1890*da0073e9SAndroid Build Coastguard Worker    def test_freeze_non_interface_module_swap(self):
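        # Swapping between submodules of the same concrete (non-interface)
        # type is expected to succeed here; unlike the interface cases above,
        # no SetAttr on an interface type is involved.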
1891*da0073e9SAndroid Build Coastguard Worker        class InnerModule(torch.nn.Module):
1892*da0073e9SAndroid Build Coastguard Worker            def __init__(self, x):
1893*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1894*da0073e9SAndroid Build Coastguard Worker                self.x = x
1895*da0073e9SAndroid Build Coastguard Worker
1896*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp: torch.Tensor) -> torch.Tensor:
1897*da0073e9SAndroid Build Coastguard Worker                return inp.relu() + self.x
1898*da0073e9SAndroid Build Coastguard Worker
1899*da0073e9SAndroid Build Coastguard Worker        class WrapperModule(torch.nn.Module):
1900*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1901*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1902*da0073e9SAndroid Build Coastguard Worker                self.option1 = InnerModule(torch.rand((2, 2)))
1903*da0073e9SAndroid Build Coastguard Worker                self.option2 = InnerModule(torch.rand((2, 2)))
1904*da0073e9SAndroid Build Coastguard Worker                self.impl = self.option1
1905*da0073e9SAndroid Build Coastguard Worker                self.idx = 0
1906*da0073e9SAndroid Build Coastguard Worker
1907*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
1908*da0073e9SAndroid Build Coastguard Worker                self.idx += 1
1909*da0073e9SAndroid Build Coastguard Worker                if self.idx % 2 == 1:
1910*da0073e9SAndroid Build Coastguard Worker                    self.impl = self.option1
1911*da0073e9SAndroid Build Coastguard Worker                else:
1912*da0073e9SAndroid Build Coastguard Worker                    self.impl = self.option2
1913*da0073e9SAndroid Build Coastguard Worker                return self.impl(x)
1914*da0073e9SAndroid Build Coastguard Worker
1915*da0073e9SAndroid Build Coastguard Worker        unfrozen = WrapperModule()
1916*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.script(unfrozen)
1917*da0073e9SAndroid Build Coastguard Worker        m.eval()
1918*da0073e9SAndroid Build Coastguard Worker        m_frozen = torch.jit.freeze(m)
1919*da0073e9SAndroid Build Coastguard Worker
1920*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((2, 2))
1921*da0073e9SAndroid Build Coastguard Worker        expected = unfrozen(x)
1922*da0073e9SAndroid Build Coastguard Worker        actual = m_frozen(x)
1923*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
1924*da0073e9SAndroid Build Coastguard Worker
1925*da0073e9SAndroid Build Coastguard Worker    @unittest.expectedFailure
1926*da0073e9SAndroid Build Coastguard Worker    def test_freeze_interface_within_object(self):
        # I don't think there's any way to create a plain Python object that
        # contains a torch.nn.Module, but just in case... I'm not sure
        # freezing would handle this case correctly, so it is marked as xfail;
        # if this ever _does_ start working, someone will need to investigate
        # to make sure it is handled correctly.
        @torch.jit.interface
        class MyIface(torch.nn.Module):
1933*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp: torch.Tensor) -> torch.Tensor:
1934*da0073e9SAndroid Build Coastguard Worker                pass
1935*da0073e9SAndroid Build Coastguard Worker
1936*da0073e9SAndroid Build Coastguard Worker        class MyImpl(torch.nn.Module):
1937*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp: torch.Tensor) -> torch.Tensor:
1938*da0073e9SAndroid Build Coastguard Worker                return inp.sin()
1939*da0073e9SAndroid Build Coastguard Worker
1940*da0073e9SAndroid Build Coastguard Worker        class MyObject:
1941*da0073e9SAndroid Build Coastguard Worker            impl: MyIface
1942*da0073e9SAndroid Build Coastguard Worker
1943*da0073e9SAndroid Build Coastguard Worker            def run(self, x):
1944*da0073e9SAndroid Build Coastguard Worker                return self.impl(x)
1945*da0073e9SAndroid Build Coastguard Worker
1946*da0073e9SAndroid Build Coastguard Worker        class WrapperModule(torch.nn.Module):
1947*da0073e9SAndroid Build Coastguard Worker            impl: MyObject
1948*da0073e9SAndroid Build Coastguard Worker
1949*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1950*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1951*da0073e9SAndroid Build Coastguard Worker                self.impl = MyObject()
1952*da0073e9SAndroid Build Coastguard Worker                self.impl.impl = MyImpl()
1953*da0073e9SAndroid Build Coastguard Worker
1954*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
1955*da0073e9SAndroid Build Coastguard Worker                return self.impl(x)
1956*da0073e9SAndroid Build Coastguard Worker
1957*da0073e9SAndroid Build Coastguard Worker        unfrozen = WrapperModule()
1958*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.script(unfrozen)
1959*da0073e9SAndroid Build Coastguard Worker        m.eval()
1960*da0073e9SAndroid Build Coastguard Worker        m_frozen = torch.jit.freeze(m)
1961*da0073e9SAndroid Build Coastguard Worker
1962*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((2, 2))
1963*da0073e9SAndroid Build Coastguard Worker        expected = unfrozen(x)
1964*da0073e9SAndroid Build Coastguard Worker        actual = m_frozen(x)
        self.assertEqual(expected, actual)
1966*da0073e9SAndroid Build Coastguard Worker
1967*da0073e9SAndroid Build Coastguard Worker    def test_freeze_non_module_class_getattr(self):
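        # Attributes of a plain TorchScript class (not an nn.Module) should
        # be folded during freezing, leaving no prim::GetAttr in the graph.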
1968*da0073e9SAndroid Build Coastguard Worker        class BoxCoder:
1969*da0073e9SAndroid Build Coastguard Worker            def __init__(self, bbox_xform_clip):
1970*da0073e9SAndroid Build Coastguard Worker                # type: (float) -> None
1971*da0073e9SAndroid Build Coastguard Worker                self.bbox_xform_clip = bbox_xform_clip
1972*da0073e9SAndroid Build Coastguard Worker
1973*da0073e9SAndroid Build Coastguard Worker            def decode(self, input):
1974*da0073e9SAndroid Build Coastguard Worker                return input * self.bbox_xform_clip
1975*da0073e9SAndroid Build Coastguard Worker
1976*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
1977*da0073e9SAndroid Build Coastguard Worker            __annotations__ = {
1978*da0073e9SAndroid Build Coastguard Worker                "box_coder": BoxCoder,
1979*da0073e9SAndroid Build Coastguard Worker            }
1980*da0073e9SAndroid Build Coastguard Worker
1981*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1982*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1983*da0073e9SAndroid Build Coastguard Worker                self.box_coder = BoxCoder(50.0)
1984*da0073e9SAndroid Build Coastguard Worker
1985*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
1986*da0073e9SAndroid Build Coastguard Worker                return self.box_coder.decode(input)
1987*da0073e9SAndroid Build Coastguard Worker
1988*da0073e9SAndroid Build Coastguard Worker        model = MyModule()
1989*da0073e9SAndroid Build Coastguard Worker        model.eval()
1990*da0073e9SAndroid Build Coastguard Worker        script_model = torch.jit.freeze(torch.jit.script(model))
1991*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn([4, 4])
        output_eager = model(inp)
        self.assertEqual(output_eager, script_model(inp))
1994*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("GetAttr").run(script_model.graph)
1995*da0073e9SAndroid Build Coastguard Worker
1996*da0073e9SAndroid Build Coastguard Worker    def test_freeze_module_with_tupleoutput_submodule(self):
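        # The submodule returns a tuple that the caller immediately unpacks;
        # after inlining, the construct/unpack pair should be optimized away.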
1997*da0073e9SAndroid Build Coastguard Worker        class SubModule(nn.Module):
1998*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1999*da0073e9SAndroid Build Coastguard Worker                return (x + 1, x + 2)
2000*da0073e9SAndroid Build Coastguard Worker
2001*da0073e9SAndroid Build Coastguard Worker        class TestModule(nn.Module):
2002*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2003*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2004*da0073e9SAndroid Build Coastguard Worker                self.sub = SubModule()
2005*da0073e9SAndroid Build Coastguard Worker
2006*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2007*da0073e9SAndroid Build Coastguard Worker                y1, y2 = self.sub(x)
2008*da0073e9SAndroid Build Coastguard Worker                return y1 + y2
2009*da0073e9SAndroid Build Coastguard Worker
2010*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.script(TestModule())
2011*da0073e9SAndroid Build Coastguard Worker        m = m.eval()
2012*da0073e9SAndroid Build Coastguard Worker        mf = torch.jit.freeze(m)
2013*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(2, 2)
2014*da0073e9SAndroid Build Coastguard Worker        expected = m.forward(inp)
2015*da0073e9SAndroid Build Coastguard Worker        output = mf.forward(inp)
        # Check that prim::TupleConstruct and prim::TupleUnpack
        # do not exist in the frozen graph.
2018*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("prim::TupleConstruct").run(mf.graph)
2019*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("prim::TupleUnpack").run(mf.graph)
2020*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output, expected)
2021*da0073e9SAndroid Build Coastguard Worker
2022*da0073e9SAndroid Build Coastguard Worker    def test_freeze_module_with_call_method(self):
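        # A preserved exported method calls forward, whose graph is rewritten
        # during freezing; the preserved method must keep matching eager.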
2023*da0073e9SAndroid Build Coastguard Worker        class Mod(nn.Module):
2024*da0073e9SAndroid Build Coastguard Worker            def __init__(self, val):
2025*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2026*da0073e9SAndroid Build Coastguard Worker                self.param = nn.Parameter(val)
2027*da0073e9SAndroid Build Coastguard Worker
2028*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2029*da0073e9SAndroid Build Coastguard Worker                # this method will change during freezing
2030*da0073e9SAndroid Build Coastguard Worker                return x + self.param
2031*da0073e9SAndroid Build Coastguard Worker
2032*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
2033*da0073e9SAndroid Build Coastguard Worker            def make_prediction(self, x):
2034*da0073e9SAndroid Build Coastguard Worker                y = x + x
2035*da0073e9SAndroid Build Coastguard Worker                return self.forward(y)
2036*da0073e9SAndroid Build Coastguard Worker
2037*da0073e9SAndroid Build Coastguard Worker        param = torch.rand([2, 2])
2038*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([2, 2])
2039*da0073e9SAndroid Build Coastguard Worker
2040*da0073e9SAndroid Build Coastguard Worker        unscripted_mod = Mod(param)
2041*da0073e9SAndroid Build Coastguard Worker        mod = torch.jit.script(unscripted_mod)
2042*da0073e9SAndroid Build Coastguard Worker        mod.eval()
2043*da0073e9SAndroid Build Coastguard Worker        mod = torch.jit.freeze(mod, preserved_attrs=["make_prediction"])
2044*da0073e9SAndroid Build Coastguard Worker
2045*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2046*da0073e9SAndroid Build Coastguard Worker            mod.forward(x), unscripted_mod.forward(x), atol=1e-5, rtol=1e-5
2047*da0073e9SAndroid Build Coastguard Worker        )
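        # Also exercise the preserved method to confirm it survived freezing.
        self.assertEqual(
            mod.make_prediction(x),
            unscripted_mod.make_prediction(x),
            atol=1e-5,
            rtol=1e-5,
        )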
2048*da0073e9SAndroid Build Coastguard Worker
2049*da0073e9SAndroid Build Coastguard Worker
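# A minimal reference sketch of the conv-bn folding algebra exercised by the
# tests below. The helper name and shape handling are this sketch's own
# assumptions (it is illustrative and not called by any test): with eval-mode
# batch norm, bn(conv(x)) equals a single conv with rescaled parameters.
def _reference_fold_conv_bn(w, b, bn):
    # per-output-channel scale: gamma / sqrt(running_var + eps)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    folded_w = w * scale.reshape([-1] + [1] * (w.ndim - 1))
    if b is None:
        # a conv without bias folds as if its bias were zero
        b = torch.zeros_like(bn.running_mean)
    folded_b = (b - bn.running_mean) * scale + bn.bias
    return folded_w, folded_b

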
2050*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo("somehow causing hanging during python shutdown")
2051*da0073e9SAndroid Build Coastguard Workerclass TestFrozenOptimizations(JitTestCase):
2052*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
2053*da0073e9SAndroid Build Coastguard Worker        super().setUp()
2054*da0073e9SAndroid Build Coastguard Worker        self.default_dtype = torch.get_default_dtype()
2055*da0073e9SAndroid Build Coastguard Worker        torch.set_default_dtype(torch.double)
2056*da0073e9SAndroid Build Coastguard Worker
2057*da0073e9SAndroid Build Coastguard Worker    def tearDown(self):
2058*da0073e9SAndroid Build Coastguard Worker        torch.set_default_dtype(self.default_dtype)
2059*da0073e9SAndroid Build Coastguard Worker        super().tearDown()
2060*da0073e9SAndroid Build Coastguard Worker
2061*da0073e9SAndroid Build Coastguard Worker    def test_conv_bn_folding(self):
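        # Conv-BN folding rewrites bn(conv(x)) into a single conv once the bn
        # statistics are frozen constants (see the _reference_fold_conv_bn
        # sketch above for the intended algebra). Without tracked running
        # stats the bn inputs stay non-constant, so those cases must not fold.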
2062*da0073e9SAndroid Build Coastguard Worker        conv_bias = [True, False]
2063*da0073e9SAndroid Build Coastguard Worker        module_pairs = [
2064*da0073e9SAndroid Build Coastguard Worker            (nn.Conv1d, nn.BatchNorm1d),
2065*da0073e9SAndroid Build Coastguard Worker            (nn.Conv2d, nn.BatchNorm2d),
2066*da0073e9SAndroid Build Coastguard Worker            (nn.Conv3d, nn.BatchNorm3d),
2067*da0073e9SAndroid Build Coastguard Worker        ]
2068*da0073e9SAndroid Build Coastguard Worker        use_tracing = [True, False]
2069*da0073e9SAndroid Build Coastguard Worker        bn_running_stats = [True, False]
2070*da0073e9SAndroid Build Coastguard Worker
2071*da0073e9SAndroid Build Coastguard Worker        for use_bias, modules, tracing, track_stats in product(
2072*da0073e9SAndroid Build Coastguard Worker            conv_bias, module_pairs, use_tracing, bn_running_stats
2073*da0073e9SAndroid Build Coastguard Worker        ):
2074*da0073e9SAndroid Build Coastguard Worker
2075*da0073e9SAndroid Build Coastguard Worker            class ConvBN(torch.nn.Module):
2076*da0073e9SAndroid Build Coastguard Worker                def __init__(self, in_channels, out_channels, **kwargs):
2077*da0073e9SAndroid Build Coastguard Worker                    super().__init__()
2078*da0073e9SAndroid Build Coastguard Worker                    self.conv = modules[0](
2079*da0073e9SAndroid Build Coastguard Worker                        in_channels, out_channels, bias=use_bias, **kwargs
2080*da0073e9SAndroid Build Coastguard Worker                    )
2081*da0073e9SAndroid Build Coastguard Worker                    self.bn = modules[1](
2082*da0073e9SAndroid Build Coastguard Worker                        out_channels, eps=0.001, track_running_stats=track_stats
2083*da0073e9SAndroid Build Coastguard Worker                    )
2084*da0073e9SAndroid Build Coastguard Worker
2085*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):
2086*da0073e9SAndroid Build Coastguard Worker                    x = self.conv(x)
2087*da0073e9SAndroid Build Coastguard Worker                    return self.bn(x)
2088*da0073e9SAndroid Build Coastguard Worker
2089*da0073e9SAndroid Build Coastguard Worker            mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).eval()
2090*da0073e9SAndroid Build Coastguard Worker            inps = [4, 3, 4]
2091*da0073e9SAndroid Build Coastguard Worker            if modules[0] == nn.Conv2d:
2092*da0073e9SAndroid Build Coastguard Worker                inps.append(inps[-1])
2093*da0073e9SAndroid Build Coastguard Worker            if modules[0] == nn.Conv3d:
2094*da0073e9SAndroid Build Coastguard Worker                inps.append(inps[-1])
2095*da0073e9SAndroid Build Coastguard Worker                inps.append(inps[-1])
2096*da0073e9SAndroid Build Coastguard Worker
2097*da0073e9SAndroid Build Coastguard Worker            inp = torch.rand(inps)
2098*da0073e9SAndroid Build Coastguard Worker
2099*da0073e9SAndroid Build Coastguard Worker            if tracing:
                scripted_mod = torch.jit.trace(mod_eager, (inp,))
2101*da0073e9SAndroid Build Coastguard Worker            else:
2102*da0073e9SAndroid Build Coastguard Worker                scripted_mod = torch.jit.script(mod_eager)
2103*da0073e9SAndroid Build Coastguard Worker
2104*da0073e9SAndroid Build Coastguard Worker            self.run_pass("inline", scripted_mod.graph)
2105*da0073e9SAndroid Build Coastguard Worker            self.run_pass("peephole", scripted_mod.graph)
2106*da0073e9SAndroid Build Coastguard Worker            self.run_pass("constant_propagation", scripted_mod.graph)
2107*da0073e9SAndroid Build Coastguard Worker
2108*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("conv").check("batch").run(scripted_mod.graph)
            # before freezing, the bn stats are still attributes rather than
            # constants, so the pass should successfully no-op
2110*da0073e9SAndroid Build Coastguard Worker            self.run_pass("fold_frozen_conv_bn", scripted_mod.graph)
2111*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("conv").check("aten::batch_norm").run(scripted_mod.graph)
2112*da0073e9SAndroid Build Coastguard Worker
2113*da0073e9SAndroid Build Coastguard Worker            scripted_mod = torch.jit.freeze(scripted_mod)
2114*da0073e9SAndroid Build Coastguard Worker            self.run_pass("fold_frozen_conv_bn", scripted_mod.graph)
2115*da0073e9SAndroid Build Coastguard Worker            if track_stats:
2116*da0073e9SAndroid Build Coastguard Worker                FileCheck().check("conv").check_not("aten::batch_norm").run(
2117*da0073e9SAndroid Build Coastguard Worker                    scripted_mod.graph
2118*da0073e9SAndroid Build Coastguard Worker                )
2119*da0073e9SAndroid Build Coastguard Worker            else:
2120*da0073e9SAndroid Build Coastguard Worker                FileCheck().check("conv").check("aten::batch_norm").run(
2121*da0073e9SAndroid Build Coastguard Worker                    scripted_mod.graph
2122*da0073e9SAndroid Build Coastguard Worker                )
2123*da0073e9SAndroid Build Coastguard Worker
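            # run twice, presumably so the second call exercises the
            # profiled/optimized execution path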
2124*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(mod_eager(inp), scripted_mod(inp))
2125*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(mod_eager(inp), scripted_mod(inp))
2126*da0073e9SAndroid Build Coastguard Worker
2127*da0073e9SAndroid Build Coastguard Worker    def test_conv_bn_folding_not_forward(self):
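        # Folding must also apply to preserved methods other than forward;
        # the folded pattern is checked in make_prediction's graph below.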
2128*da0073e9SAndroid Build Coastguard Worker        class ConvBN(torch.nn.Module):
2129*da0073e9SAndroid Build Coastguard Worker            def __init__(self, in_channels, out_channels, **kwargs):
2130*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2131*da0073e9SAndroid Build Coastguard Worker                self.conv = torch.nn.Conv2d(
2132*da0073e9SAndroid Build Coastguard Worker                    in_channels, out_channels, bias=True, **kwargs
2133*da0073e9SAndroid Build Coastguard Worker                )
2134*da0073e9SAndroid Build Coastguard Worker                self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
2135*da0073e9SAndroid Build Coastguard Worker                self.amt = 3.2
2136*da0073e9SAndroid Build Coastguard Worker
2137*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2138*da0073e9SAndroid Build Coastguard Worker                x = self.conv(x)
2139*da0073e9SAndroid Build Coastguard Worker                return self.bn(x)
2140*da0073e9SAndroid Build Coastguard Worker
2141*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
2142*da0073e9SAndroid Build Coastguard Worker            def make_prediction(self, x):
2143*da0073e9SAndroid Build Coastguard Worker                return self.forward(x) + self.amt
2144*da0073e9SAndroid Build Coastguard Worker
2145*da0073e9SAndroid Build Coastguard Worker        mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).eval()
2146*da0073e9SAndroid Build Coastguard Worker        scripted_mod = torch.jit.script(mod_eager)
2147*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_inline(scripted_mod.make_prediction.graph)
2148*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("conv").check("aten::batch_norm").run(
2149*da0073e9SAndroid Build Coastguard Worker            scripted_mod.make_prediction.graph
2150*da0073e9SAndroid Build Coastguard Worker        )
2151*da0073e9SAndroid Build Coastguard Worker
2152*da0073e9SAndroid Build Coastguard Worker        # _jit_pass_optimize_frozen_graph should not be called on non-method attributes (e.g. "amt")
2153*da0073e9SAndroid Build Coastguard Worker        scripted_mod = torch.jit.freeze(
2154*da0073e9SAndroid Build Coastguard Worker            scripted_mod, preserved_attrs=["make_prediction", "amt"]
2155*da0073e9SAndroid Build Coastguard Worker        )
2156*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("conv").check_not("aten::batch_norm").run(
2157*da0073e9SAndroid Build Coastguard Worker            scripted_mod.make_prediction.graph
2158*da0073e9SAndroid Build Coastguard Worker        )
2159*da0073e9SAndroid Build Coastguard Worker
    # During freezing this creates tensor constants that are attached to the frozen graph,
    # which is then kept alive by the compilation unit (causing a leak)
2162*da0073e9SAndroid Build Coastguard Worker    @skipCUDAMemoryLeakCheckIf(True)
2163*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
2164*da0073e9SAndroid Build Coastguard Worker    def test_conv_bn_folding_autocast_scenario_cuda(self):
        # The CUDA conv kernels require all input tensors to share one dtype,
        # so folding must not produce conv inputs with mismatched dtypes.
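        # In particular, the folded bias should be produced in the conv's
        # dtype (half here), which the assertions below verify.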
2167*da0073e9SAndroid Build Coastguard Worker
2168*da0073e9SAndroid Build Coastguard Worker        class ConvBN(torch.nn.Module):
2169*da0073e9SAndroid Build Coastguard Worker            def __init__(self, in_channels, out_channels, **kwargs):
2170*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2171*da0073e9SAndroid Build Coastguard Worker                self.conv = torch.nn.Conv2d(
2172*da0073e9SAndroid Build Coastguard Worker                    in_channels, out_channels, bias=False, dtype=torch.half, **kwargs
2173*da0073e9SAndroid Build Coastguard Worker                )
2174*da0073e9SAndroid Build Coastguard Worker                self.bn = torch.nn.BatchNorm2d(
2175*da0073e9SAndroid Build Coastguard Worker                    out_channels, eps=0.001, dtype=torch.float
2176*da0073e9SAndroid Build Coastguard Worker                )
2177*da0073e9SAndroid Build Coastguard Worker
2178*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2179*da0073e9SAndroid Build Coastguard Worker                return self.bn(self.conv(x))
2180*da0073e9SAndroid Build Coastguard Worker
2181*da0073e9SAndroid Build Coastguard Worker        mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).cuda().eval()
2182*da0073e9SAndroid Build Coastguard Worker        scripted_mod = torch.jit.script(mod_eager)
2183*da0073e9SAndroid Build Coastguard Worker        scripted_mod = torch.jit.freeze(scripted_mod)
2184*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("conv").check_not("aten::batch_norm").run(scripted_mod.graph)
2185*da0073e9SAndroid Build Coastguard Worker        conv_node = scripted_mod.graph.findNode("aten::conv2d", True)
2186*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(conv_node is not None)
2187*da0073e9SAndroid Build Coastguard Worker        bias_input = conv_node.namedInput("bias")
2188*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(bias_input is not None)
2189*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(bias_input.type().dtype() == torch.half)
2190*da0073e9SAndroid Build Coastguard Worker
2191*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((3, 3, 32, 32), dtype=torch.half).cuda()
2192*da0073e9SAndroid Build Coastguard Worker
2193*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2)
2194*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2)
2195*da0073e9SAndroid Build Coastguard Worker
    def test_conv_add_folding(self):
        @torch.no_grad()
        def test_conv_fusion(
            use_bias, module, tracing, op, scalar, add_tensor, expect_success
        ):
            class ConvOp(torch.nn.Module):
                __constants__ = ["use_scalar"]

                def __init__(self, in_channels, out_channels, tensor=None, **kwargs):
                    super().__init__()
                    self.conv = module(
                        in_channels, out_channels, bias=use_bias, **kwargs
                    )
                    self.conv2 = module(
                        in_channels, out_channels, bias=use_bias, **kwargs
                    )
                    self.use_scalar = scalar
                    tensor_size = [1 for _ in range(self.conv.weight.ndim)]
                    tensor_size[1] = self.conv.weight.size(0)
                    self.tensor = (
                        add_tensor
                        if add_tensor is not None
                        else torch.rand(tensor_size)
                    )
                    self.op = op

                def forward(self, x):
                    x = self.conv(x)
                    if self.use_scalar:
                        return self.op(x, 2.0)
                    else:
                        return self.op(x, self.tensor)

            mod_eager = ConvOp(3, 32, kernel_size=3, stride=2).eval()

            inps = [4, 3, 4]
            if module == nn.Conv2d:
                inps.append(inps[-1])
            if module == nn.Conv3d:
                inps.append(inps[-1])
                inps.append(inps[-1])

            inp = torch.rand(inps)

            if tracing:
                scripted_mod = torch.jit.trace(mod_eager, (inp,))
            else:
                scripted_mod = torch.jit.script(mod_eager)

            self.run_pass("inline", scripted_mod.graph)
            op_str = "aten::" + op.__name__

            FileCheck().check("conv").check(op_str).run(scripted_mod.graph)
            # successfully no-ops with non-const inputs
            self.run_pass("fold_frozen_conv_mul_or_div", scripted_mod.graph)
            self.run_pass("fold_frozen_conv_add_or_sub", scripted_mod.graph)
            FileCheck().check("conv").check(op_str).run(scripted_mod.graph)
            scripted_mod = torch.jit.freeze(scripted_mod)
            self.run_pass("fold_frozen_conv_mul_or_div", scripted_mod.graph)
            self.run_pass("fold_frozen_conv_add_or_sub", scripted_mod.graph)

            if expect_success:
                FileCheck().check("conv").check_not(op_str).run(scripted_mod.graph)
            else:
                FileCheck().check("conv").check(op_str).run(scripted_mod.graph)

            self.assertEqual(mod_eager(inp), scripted_mod(inp))
            self.assertEqual(mod_eager(inp), scripted_mod(inp))

        conv_bias = [True, False]
        modules = [nn.Conv1d, nn.Conv2d, nn.Conv3d]
        use_tracing = [False, True]
        use_scalar = [False, True]
        ops = [torch.add, torch.sub, torch.mul, torch.div]

        for use_bias, module, tracing, pytorch_op, scalar in product(
            conv_bias, modules, use_tracing, ops, use_scalar
        ):
            test_conv_fusion(
                use_bias,
                module,
                tracing,
                pytorch_op,
                scalar,
                add_tensor=None,
                expect_success=True,
            )

        for use_bias, pytorch_op in product(conv_bias, ops):
            # broadcasting tensor that spans spatial dims - cannot be folded
            test_conv_fusion(
                use_bias,
                nn.Conv2d,
                False,
                pytorch_op,
                False,
                add_tensor=torch.rand(32, 1, 32),
                expect_success=False,
            )

            # single-element tensor broadcasts like a scalar - foldable
            test_conv_fusion(
                use_bias,
                nn.Conv2d,
                False,
                pytorch_op,
                False,
                add_tensor=torch.rand(1, 1),
                expect_success=True,
            )

            # add with a different dtype
            test_conv_fusion(
                use_bias,
                nn.Conv2d,
                False,
                pytorch_op,
                False,
                add_tensor=torch.tensor([2]).to(torch.int),
                expect_success=True,
            )

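    # A sketch (helper name is ours) of what fold_frozen_conv_mul_or_div and
    # fold_frozen_conv_add_or_sub do once the second operand is constant: a
    # mul/div scales both weight and bias, while an add/sub shifts only the
    # bias. Only constants that broadcast like one value per output channel
    # qualify, which is why the (32, 1, 32) case above must not fold.
    @staticmethod
    def _reference_fold_conv_binary(conv, op, const):
        c = torch.as_tensor(const).reshape(-1).to(conv.weight.dtype)
        c = c.expand(conv.out_channels)  # per-output-channel constant
        bias = (
            conv.bias.detach().clone()
            if conv.bias is not None
            else torch.zeros(conv.out_channels, dtype=conv.weight.dtype)
        )
        weight = conv.weight.detach().clone()
        if op in (torch.mul, torch.div):
            weight = op(weight, c.reshape(-1, *([1] * (weight.ndim - 1))))
            bias = op(bias, c)
        else:  # torch.add / torch.sub only shift the bias
            bias = op(bias, c)
        return weight, bias
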
    def test_conv_mul_add_bn(self):
        class Conv_Mul_Add_Bn(nn.Module):
            def __init__(self, in_channels, out_channels, **kwargs):
                super().__init__()
                self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
                self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
                self.tensor1 = torch.tensor(2.2)
                self.tensor2 = torch.tensor(2)

            def forward(self, x):
                return self.bn(
                    torch.add(torch.mul(self.conv(x), self.tensor1), self.tensor2)
                )

        input = torch.randn(8, 3, 64, 64)
        model = Conv_Mul_Add_Bn(3, 32, kernel_size=3, stride=1).eval()

        with torch.no_grad():
            result = model(input)
            traced_model = torch.jit.trace(model, input).eval()
            traced_model = torch.jit.freeze(traced_model)
            tresult = traced_model(input)
            self.assertEqual(result, tresult)
            FileCheck().check("conv").check_not("aten::batch_norm").run(
                traced_model.graph
            )
            FileCheck().check("conv").check_not("aten::add").run(traced_model.graph)

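    # The test above chains two foldings. Since the constant mul and add are
    # affine, they fold into the conv first, and the eval-mode bn then folds
    # into the result, leaving a single conv:
    #
    #   bn(conv(x) * a + b) = bn(conv'(x))   with W' = a * W, b' = a * bias + b
    #
    # which is why both aten::batch_norm and aten::add are checked to be gone.
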
    def test_linear_bn_folding(self):
        module_pairs = [
            (nn.Linear, nn.BatchNorm1d),
            (nn.Linear, nn.BatchNorm2d),
            (nn.Linear, nn.BatchNorm3d),
        ]
        use_tracing = [True, False]
        bn_running_stats = [True, False]

        for modules, tracing, track_stats in product(
            module_pairs, use_tracing, bn_running_stats
        ):

            class LinearBN(torch.nn.Module):
                def __init__(self, in_features, out_features):
                    super().__init__()
                    self.linear = modules[0](in_features, out_features)
                    self.bn = modules[1](
                        out_features, eps=0.001, track_running_stats=track_stats
                    )

                def forward(self, x):
                    x = self.linear(x)
                    return self.bn(x)

            mod_eager = LinearBN(32, 32).eval()

            inps = [3, 32]
            if modules[1] == nn.BatchNorm2d:
                inps.append(inps[-1])
                inps.append(inps[-1])
            if modules[1] == nn.BatchNorm3d:
                inps.append(inps[-1])
                inps.append(inps[-1])
                inps.append(inps[-1])

            inp = torch.rand(inps)

            if tracing:
                scripted_mod = torch.jit.trace(mod_eager, (inp,))
            else:
                scripted_mod = torch.jit.script(mod_eager)

            self.run_pass("inline", scripted_mod.graph)
            self.run_pass("peephole", scripted_mod.graph)
            self.run_pass("constant_propagation", scripted_mod.graph)

            FileCheck().check("linear").check("batch").run(scripted_mod.graph)
            # successfully no-ops with non-const inputs
            self.run_pass("fold_frozen_linear_bn", scripted_mod.graph)
            FileCheck().check("linear").check("aten::batch_norm").run(
                scripted_mod.graph
            )

            scripted_mod = torch.jit.freeze(scripted_mod)
            self.run_pass("fold_frozen_linear_bn", scripted_mod.graph)
            if track_stats:
                FileCheck().check("linear").check_not("aten::batch_norm").run(
                    scripted_mod.graph
                )
            else:
                FileCheck().check("linear").check("aten::batch_norm").run(
                    scripted_mod.graph
                )

            self.assertEqual(mod_eager(inp), scripted_mod(inp))
            self.assertEqual(mod_eager(inp), scripted_mod(inp))

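    # Eager-mode counterpart (our helper, for illustration) of the folding
    # checked above, using the fusion utility these bn tests also exercise.
    # With running stats tracked, eval-mode bn is a constant affine transform
    # and can be absorbed into the linear; without running stats, eval-mode
    # bn still normalizes with batch statistics, so there is nothing constant
    # to fold, which is why the track_stats=False cases keep aten::batch_norm.
    @staticmethod
    def _fuse_linear_bn_eager_example():
        linear = nn.Linear(32, 32).eval()
        bn = nn.BatchNorm1d(32, eps=0.001).eval()
        fused = nn.utils.fusion.fuse_linear_bn_eval(linear, bn)
        x = torch.rand(3, 32)
        torch.testing.assert_close(fused(x), bn(linear(x)))
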
    def test_bn_not_broadcast_with_linear(self):
        module_pairs = [
            (nn.Linear, nn.BatchNorm1d),
            (nn.Linear, nn.BatchNorm2d),
            (nn.Linear, nn.BatchNorm3d),
        ]
        use_tracing = [True, False]
        linear_in = 3
        # (linear_out, bn_in)
        # case 1: linear_out < bn_in
        # case 2: linear_out > bn_in
        # case 3: linear_out != bn_in and linear_out == 1
        dims = [(2, 4), (4, 2), (1, 2)]

        for modules, tracing, dim in product(module_pairs, use_tracing, dims):
            linear_out, bn_in = dim[0], dim[1]

            linear = modules[0](linear_in, linear_out)
            bn = modules[1](bn_in)
            mod_eager = nn.Sequential(linear, bn).eval()

            N, C = 3, bn_in
            input_shape = [N, C]
            if modules[1] == nn.BatchNorm1d:
                H = linear_in
                input_shape.append(H)
            elif modules[1] == nn.BatchNorm2d:
                H, W = 4, linear_in
                input_shape.append(H)
                input_shape.append(W)
            elif modules[1] == nn.BatchNorm3d:
                D, H, W = 4, 4, linear_in
                input_shape.append(D)
                input_shape.append(H)
                input_shape.append(W)

            inp = torch.rand(input_shape)

            if tracing:
                scripted_mod = torch.jit.trace(mod_eager, (inp,))
            else:
                scripted_mod = torch.jit.script(mod_eager)

            self.run_pass("inline", scripted_mod.graph)
            self.run_pass("peephole", scripted_mod.graph)
            self.run_pass("constant_propagation", scripted_mod.graph)

            FileCheck().check("linear").check("batch").run(scripted_mod.graph)
            self.run_pass("fold_frozen_linear_bn", scripted_mod.graph)
            FileCheck().check("linear").check("aten::batch_norm").run(
                scripted_mod.graph
            )

            frozen_mod = torch.jit.freeze(scripted_mod)
            self.run_pass("fold_frozen_linear_bn", frozen_mod.graph)
            # folding is correctly skipped when the shapes do not line up
            FileCheck().check("linear").check("aten::batch_norm").run(frozen_mod.graph)

            self.assertEqual(mod_eager(inp), frozen_mod(inp))
            self.assertEqual(mod_eager(inp), frozen_mod(inp))

            # the eager fusion helper rejects the same mismatch
            with self.assertRaisesRegex(
                AssertionError,
                "To fuse, linear.out_features == bn.num_features or bn.num_features == 1",
            ):
                nn.utils.fusion.fuse_linear_bn_eval(linear, bn)

    @skipCUDAMemoryLeakCheckIf(True)
    @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
    def test_linear_bn_folding_autocast_scenario_cuda(self):
        module_pairs = [
            (nn.Linear, nn.BatchNorm1d),
            (nn.Linear, nn.BatchNorm2d),
            (nn.Linear, nn.BatchNorm3d),
        ]
        use_tracing = [True, False]
        bn_running_stats = [True, False]

        for modules, tracing, track_stats in product(
            module_pairs, use_tracing, bn_running_stats
        ):

            class LinearBN(torch.nn.Module):
                def __init__(self, in_features, out_features):
                    super().__init__()
                    self.linear = modules[0](
                        in_features, out_features, bias=False, dtype=torch.half
                    )
                    self.bn = modules[1](out_features, eps=0.001, dtype=torch.float)

                def forward(self, x):
                    x = self.linear(x)
                    return self.bn(x)

            mod_eager = LinearBN(32, 32).cuda().eval()

            inps = [3, 32]
            if modules[1] == nn.BatchNorm2d:
                inps.append(inps[-1])
                inps.append(inps[-1])
            if modules[1] == nn.BatchNorm3d:
                inps.append(inps[-1])
                inps.append(inps[-1])
                inps.append(inps[-1])

            x = torch.rand(inps, dtype=torch.half).cuda()

            if tracing:
                scripted_mod = torch.jit.trace(mod_eager, (x,))
            else:
                scripted_mod = torch.jit.script(mod_eager)
            scripted_mod = torch.jit.freeze(scripted_mod)
            FileCheck().check("linear").check_not("aten::batch_norm").run(
                scripted_mod.graph
            )
            lin_node = scripted_mod.graph.findNode("aten::linear", True)
            self.assertTrue(lin_node is not None)
            weight_input = lin_node.namedInput("weight")
            bias_input = lin_node.namedInput("bias")
            self.assertTrue(bias_input is not None)
            self.assertTrue(weight_input.type().dtype() == torch.half)
            self.assertTrue(bias_input.type().dtype() == torch.half)

            self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2)
            self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2)

    @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
    def test_linear_concat(self):
        out_dims = [[5, 10], [1, 5]]

        for w1_dim, w2_dim in out_dims:

            class ModMultLinear(nn.Module):
                def __init__(self, w1_dim, w2_dim):
                    super().__init__()
                    self.w1 = nn.Parameter(torch.rand([w1_dim, 5]))
                    self.b1 = nn.Parameter(torch.rand([w1_dim]))
                    self.w2 = nn.Parameter(torch.rand([w2_dim, 5]))
                    self.b2 = nn.Parameter(torch.rand([w2_dim]))

                def forward(self, in_tensor1):
                    res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1)
                    res2 = torch._C._nn.linear(in_tensor1, self.w2, self.b2)
                    return res1, res2

            mod_eager = ModMultLinear(w1_dim, w2_dim).eval()

            test_val1 = torch.rand([50, 5])
            self.check_linear_optimizations(mod_eager, 2, 1, (test_val1,))

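    # A sketch (ours) of the rewrite concat_frozen_linear is checked for
    # above: two linears that share an input run as one wider linear whose
    # output is then split back into the original results.
    @staticmethod
    def _reference_concat_linear(x, w1, b1, w2, b2):
        w = torch.cat([w1, w2], dim=0)
        b = torch.cat([b1, b2], dim=0)
        out = F.linear(x, w, b)
        return out.split([w1.size(0), w2.size(0)], dim=-1)
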
    @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
    def test_linear_concat_complex(self):
        """
        Test that interleaved optimization opportunities do not cause
        errors and that the graph is optimized as expected.
        """

        class ModMultLinear(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                w1_dim = 5
                w2_dim = 10
                self.w1 = nn.Parameter(torch.rand([w1_dim, 5]))
                self.b1 = nn.Parameter(torch.rand([w1_dim]))
                self.w2 = nn.Parameter(torch.rand([w2_dim, 5]))
                self.b2 = nn.Parameter(torch.rand([w2_dim]))

            def forward(self, in_tensor1):
                res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1)
                res3 = torch._C._nn.linear(res1, self.w2, self.b2)
                res2 = torch._C._nn.linear(in_tensor1, self.w2, self.b2)
                res4 = torch._C._nn.linear(res1, self.w1, self.b1)
                return res2, res3, res4

        mod_eager = ModMultLinear().eval()
        test_val1 = torch.rand([50, 5])
        self.check_linear_optimizations(mod_eager, 4, 2, (test_val1,))

    @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
    def test_linear_concat_different_input(self):
        """
        The optimization pass should leave the graph unchanged because
        the two linears consume different input tensors.
        """

        # Freezing requires that the graph be a module
        class ModMultLinear(nn.Module):
            def __init__(self, w1_dim, w2_dim):
                super().__init__()
                self.w1 = nn.Parameter(torch.rand([w1_dim, 5]))
                self.b1 = nn.Parameter(torch.rand([w1_dim]))
                self.w2 = nn.Parameter(torch.rand([w2_dim, 5]))
                self.b2 = nn.Parameter(torch.rand([w2_dim]))

            def forward(self, in_tensor1, in_tensor2):
                res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1)
                res2 = torch._C._nn.linear(in_tensor2, self.w2, self.b2)
                return res1, res2

        mod_eager = ModMultLinear(5, 5).eval()
        test_val1 = torch.rand([50, 5])
        test_val2 = torch.rand([50, 5])
        self.check_linear_optimizations(mod_eager, 2, 2, (test_val1, test_val2))

    @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
    def test_linear_multiple_blocks(self):
        class ModMultLinear(nn.Module):
            def __init__(self, w1_dim, w2_dim):
                super().__init__()
                self.w1 = nn.Parameter(torch.rand([w1_dim, 5]))
                self.b1 = nn.Parameter(torch.rand([w1_dim]))
                self.w2 = nn.Parameter(torch.rand([w2_dim, 5]))
                self.b2 = nn.Parameter(torch.rand([w2_dim]))

            def forward(self, in_tensor1, in_tensor2, cond: bool):
                res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1)
                if cond:
                    res3 = torch._C._nn.linear(in_tensor2, self.w2, self.b2)
                    res4 = torch._C._nn.linear(in_tensor1, self.w2, self.b1)
                else:
                    raise AssertionError
                res2 = torch._C._nn.linear(in_tensor1, self.w2, self.b1)
                return res1, res2, res3, res4

        mod_eager = ModMultLinear(5, 5).eval()
        test_val1 = torch.rand([50, 5])
        test_val2 = torch.rand([50, 5])
        self.check_linear_optimizations(mod_eager, 4, 3, (test_val1, test_val2, True))

    def check_linear_optimizations(
        self, eager_mod, orig_linears, new_linears, test_vals
    ):
        for is_cuda in [False, True]:
            if is_cuda:
                mod_to_device = eager_mod.cuda()
                test_vals_to_device = [
                    t.cuda() if isinstance(t, torch.Tensor) else t for t in test_vals
                ]
            else:
                mod_to_device = eager_mod
                test_vals_to_device = test_vals

            script_mod = torch.jit.script(mod_to_device)
            op_graph = script_mod.graph

            FileCheck().check_count("aten::linear", orig_linears, exactly=True).run(
                op_graph
            )
            # successfully no-ops with non-const inputs
            self.run_pass("concat_frozen_linear", op_graph)
            FileCheck().check_count("aten::linear", orig_linears, exactly=True).run(
                op_graph
            )

            script_mod = torch.jit.freeze(script_mod)
            op_graph = script_mod.graph
            self.run_pass("concat_frozen_linear", op_graph)
            if is_cuda:
                FileCheck().check_count("aten::linear", new_linears, exactly=True).run(
                    op_graph
                )
            else:
                FileCheck().check_count("aten::linear", orig_linears, exactly=True).run(
                    op_graph
                )

            self.assertEqual(
                mod_to_device(*test_vals_to_device), script_mod(*test_vals_to_device)
            )

    def test_optimize_freeze_module(self):
        in_channels, out_channels = 3, 32
        conv = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=2, bias=True
        )
        bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
        mod = torch.nn.Sequential(conv, bn)
        # set optimize_numerics to False here; by default, freezing runs
        # run_frozen_optimizations
        frozen_mod = torch.jit.freeze(
            torch.jit.script(mod.eval()), optimize_numerics=False
        )
        # inspect the frozen module: batch_norm should still be present
        FileCheck().check("batch_norm").run(frozen_mod.graph)
        torch.jit.run_frozen_optimizations(frozen_mod)
        FileCheck().check_not("batch_norm").run(frozen_mod.graph)

        # with the default arguments, run_frozen_optimizations runs as part of freezing
        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()))
        FileCheck().check_not("batch_norm").run(frozen_mod.graph)

    def test_freeze_remove_dropout(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.dropout = nn.Dropout(0.5)

            def forward(self, x):
                return self.dropout(x)

        mod = torch.jit.script(Net())
        # inspect mod
        torch._C._jit_pass_inline(mod.graph)
        FileCheck().check("aten::dropout").run(mod.graph)
        frozen_mod = torch.jit.freeze(mod.eval())
        FileCheck().check_not("aten::dropout").run(frozen_mod.graph)

        input = torch.randn(2)
        output_s = mod.forward(input)
        output_f = frozen_mod.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_remove_feature_dropout(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.dropout = nn.Dropout2d(0.5)

            def forward(self, x):
                return self.dropout(x)

        mod = torch.jit.script(Net().eval())
        # inspect mod
        torch._C._jit_pass_inline(mod.graph)
        FileCheck().check("aten::feature_dropout").run(mod.graph)
        frozen_mod = torch.jit.freeze(mod)
        FileCheck().check_not("aten::feature_dropout").run(frozen_mod.graph)

        input = torch.randn(2, 2, 1, 1)
        output_s = mod.forward(input)
        output_f = frozen_mod.forward(input)
        self.assertEqual(output_s, output_f)

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_freeze_mkldnn(self):
        conv = torch.nn.Conv2d(3, 32, kernel_size=3, stride=2).eval().float()
        convmkl = mkldnn_utils.to_mkldnn(conv)
        out = torch.jit.freeze(torch.jit.script(convmkl.eval()))
        inp = torch.rand([4, 3, 4, 4]).float()
        self.assertEqual(out(inp.to_mkldnn()).to_dense(), conv(inp))

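    # Note on the test above: to_mkldnn() converts the module's parameters to
    # the opaque MKL-DNN (oneDNN) blocked layout, so the frozen module expects
    # inputs already converted with Tensor.to_mkldnn() and returns mkldnn
    # tensors that must be brought back with to_dense() before comparison.
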
    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_conv_to_mkldnn(self):
        with set_default_dtype(torch.float):
            for module, trace in product([nn.Conv2d, nn.Conv3d], [False, True]):
                mod = module(3, 32, kernel_size=3, stride=2).eval()
                inps = [4, 3, 4]
                if module == nn.Conv2d:
                    inps.append(inps[-1])
                if module == nn.Conv3d:
                    inps.append(inps[-1])
                    inps.append(inps[-1])

                inp = torch.rand(inps)
                if trace:
                    scripted_mod = torch.jit.trace(mod, (inp,))
                else:
                    scripted_mod = torch.jit.script(mod)

                self.run_pass("inline", scripted_mod.graph)

                FileCheck().check("conv").run(scripted_mod.graph)
                # successfully no-ops with non-const inputs
                self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
                FileCheck().check_not("to_mkldnn").run(scripted_mod.graph)

                scripted_mod = torch.jit.freeze(scripted_mod)
                self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
                FileCheck().check("to_mkldnn").check("prim::mkldnn_convolution").check(
                    "to_dense"
                ).run(scripted_mod.graph)

                self.assertEqual(mod(inp), scripted_mod(inp))
                self.assertEqual(mod(inp), scripted_mod(inp))

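    # Hand-written eager equivalent (a sketch; the helper name is ours) of
    # the to_mkldnn -> mkldnn_convolution -> to_dense pattern the pass is
    # checked for above; assumes an MKL-DNN build and float32 inputs.
    @staticmethod
    def _manual_mkldnn_conv_example():
        conv = nn.Conv2d(3, 32, kernel_size=3, stride=2).float().eval()
        mk = mkldnn_utils.to_mkldnn(conv)
        inp = torch.rand(4, 3, 8, 8).float()
        with torch.no_grad():
            out = mk(inp.to_mkldnn()).to_dense()
            torch.testing.assert_close(out, conv(inp))
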
    def test_linear_transpose(self):
        class ModLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = torch.nn.Parameter(torch.rand(30))
                self.weight = torch.nn.Parameter(torch.rand([30, 20]))

            def forward(self, x):
                return torch._C._nn.linear(x, self.weight, self.bias)

        mod_eager = ModLinear().eval()
        test_val = torch.rand([50, 20])
        self.check_linear_optimizations_2(
            mod_eager, 1, 0, "transpose_frozen_linear", (test_val,)
        )

    def test_linear_non_constant_weight(self):
        class ModLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = torch.nn.Parameter(torch.rand(30))

            def forward(self, x, weight):
                return torch._C._nn.linear(x, weight, self.bias)

        mod_eager = ModLinear().eval()
        test_val = torch.rand([50, 20])
        test_weight = torch.rand([30, 20])
        self.check_linear_optimizations_2(
            mod_eager, 1, 1, "transpose_frozen_linear", (test_val, test_weight)
        )

    def check_linear_optimizations_2(
        self, eager_mod, orig_linears, new_linears, opt_pass, test_vals
    ):
        # TODO: merge with check_linear_optimizations once both diffs land
        mod_to_device = eager_mod
        test_vals_to_device = test_vals

        script_mod = torch.jit.script(mod_to_device)
        op_graph = script_mod.graph

        FileCheck().check_count("aten::linear", orig_linears, exactly=True).run(
            op_graph
        )
        # successfully no-ops with non-const inputs
        self.run_pass(opt_pass, op_graph)
        FileCheck().check_count("aten::linear", orig_linears, exactly=True).run(
            op_graph
        )

        script_mod = torch.jit.freeze(script_mod)
        op_graph = script_mod.graph
        self.run_pass(opt_pass, op_graph)
        FileCheck().check_count("aten::linear", new_linears, exactly=True).run(op_graph)

        self.assertEqual(
            mod_to_device(*test_vals_to_device), script_mod(*test_vals_to_device)
        )

2853*da0073e9SAndroid Build Coastguard Worker    @staticmethod
2854*da0073e9SAndroid Build Coastguard Worker    def conv():
2855*da0073e9SAndroid Build Coastguard Worker        # Generic composable conv for testing purposes
2856*da0073e9SAndroid Build Coastguard Worker        return nn.Conv2d(8, 8, 1)
2857*da0073e9SAndroid Build Coastguard Worker
2858*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
2859*da0073e9SAndroid Build Coastguard Worker        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
2860*da0073e9SAndroid Build Coastguard Worker    )
2861*da0073e9SAndroid Build Coastguard Worker    def test_collapse_adjacent_conversions(self):
2862*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.float):
2863*da0073e9SAndroid Build Coastguard Worker            mod = nn.Sequential(self.conv(), self.conv()).eval()
2864*da0073e9SAndroid Build Coastguard Worker            scripted_mod = torch.jit.script(mod)
2865*da0073e9SAndroid Build Coastguard Worker            scripted_mod = torch.jit.freeze(scripted_mod)
2866*da0073e9SAndroid Build Coastguard Worker            self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
2867*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("to_mkldnn").check("prim::mkldnn_convolution").check(
2868*da0073e9SAndroid Build Coastguard Worker                "prim::mkldnn_convolution"
2869*da0073e9SAndroid Build Coastguard Worker            ).check("to_dense").run(scripted_mod.graph)
2870*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_count("to_mkldnn", 1, exactly=True).run(
2871*da0073e9SAndroid Build Coastguard Worker                scripted_mod.graph
2872*da0073e9SAndroid Build Coastguard Worker            )
2873*da0073e9SAndroid Build Coastguard Worker
2874*da0073e9SAndroid Build Coastguard Worker            inp = torch.rand([1, 8, 8, 8])
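2874*da0073e9SAndroid Build Coastguard Worker            # run twice; the second call exercises any cached/optimized execution plan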
2875*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(scripted_mod(inp), mod(inp))
2876*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(scripted_mod(inp), mod(inp))
2877*da0073e9SAndroid Build Coastguard Worker
2878*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
2879*da0073e9SAndroid Build Coastguard Worker        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
2880*da0073e9SAndroid Build Coastguard Worker    )
2881*da0073e9SAndroid Build Coastguard Worker    def test_mkldnn_fuser_broadcasting(self):
2882*da0073e9SAndroid Build Coastguard Worker        class Add(nn.Module):
2883*da0073e9SAndroid Build Coastguard Worker            def __init__(self, tensor):
2884*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2885*da0073e9SAndroid Build Coastguard Worker                self.tensor = tensor
2886*da0073e9SAndroid Build Coastguard Worker
2887*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2888*da0073e9SAndroid Build Coastguard Worker                return x + self.tensor
2889*da0073e9SAndroid Build Coastguard Worker
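2889*da0073e9SAndroid Build Coastguard Worker        # eager MKLDNN addition cannot broadcast (asserted below), so the pass must insert an explicit broadcast op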
2890*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.float):
2891*da0073e9SAndroid Build Coastguard Worker            for add_inp in [8], [8, 8, 1]:
2892*da0073e9SAndroid Build Coastguard Worker                mod = nn.Sequential(self.conv(), Add(torch.rand(add_inp))).eval()
2893*da0073e9SAndroid Build Coastguard Worker                scripted_mod = torch.jit.script(mod)
2894*da0073e9SAndroid Build Coastguard Worker                scripted_mod = torch.jit.freeze(scripted_mod)
2895*da0073e9SAndroid Build Coastguard Worker                self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
2896*da0073e9SAndroid Build Coastguard Worker                FileCheck().check("prim::BroadcastMKLDNNTensors").run(
2897*da0073e9SAndroid Build Coastguard Worker                    scripted_mod.graph
2898*da0073e9SAndroid Build Coastguard Worker                )
2899*da0073e9SAndroid Build Coastguard Worker                inp = torch.rand([1, 8, 8, 8])
2900*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(scripted_mod(inp), mod(inp))
2901*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(scripted_mod(inp), mod(inp))
2902*da0073e9SAndroid Build Coastguard Worker
2903*da0073e9SAndroid Build Coastguard Worker                # for good measure, check that eager MKLDNN addition still fails to
2904*da0073e9SAndroid Build Coastguard Worker                # broadcast, so this op can be removed if broadcasting is ever supported
2905*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, ""):
2906*da0073e9SAndroid Build Coastguard Worker                    (
2907*da0073e9SAndroid Build Coastguard Worker                        torch.rand([1, 8, 8, 8]).to_mkldnn()
2908*da0073e9SAndroid Build Coastguard Worker                        + torch.rand(add_inp).to_mkldnn()
2909*da0073e9SAndroid Build Coastguard Worker                    )
2910*da0073e9SAndroid Build Coastguard Worker
2911*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
2912*da0073e9SAndroid Build Coastguard Worker        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
2913*da0073e9SAndroid Build Coastguard Worker    )
2914*da0073e9SAndroid Build Coastguard Worker    def test_mkldnn_inplace_removal(self):
2915*da0073e9SAndroid Build Coastguard Worker        class AddMul(nn.Module):
2916*da0073e9SAndroid Build Coastguard Worker            def __init__(self, tensor):
2917*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2918*da0073e9SAndroid Build Coastguard Worker                self.tensor = tensor
2919*da0073e9SAndroid Build Coastguard Worker
2920*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2921*da0073e9SAndroid Build Coastguard Worker                return x.add_(self.tensor).div_(self.tensor) - 4
2922*da0073e9SAndroid Build Coastguard Worker
2923*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.float):
2924*da0073e9SAndroid Build Coastguard Worker            mod = nn.Sequential(self.conv(), AddMul(torch.rand([8]))).eval()
2925*da0073e9SAndroid Build Coastguard Worker            scripted_mod = torch.jit.script(mod)
2926*da0073e9SAndroid Build Coastguard Worker            scripted_mod = torch.jit.freeze(scripted_mod)
2927*da0073e9SAndroid Build Coastguard Worker            self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
2928*da0073e9SAndroid Build Coastguard Worker            # the in-place add is functionalized and then re-inplaced by the pass
2929*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("aten::to_mkldnn").check("aten::add_").check(
2930*da0073e9SAndroid Build Coastguard Worker                "aten::div_"
2931*da0073e9SAndroid Build Coastguard Worker            ).run(scripted_mod.graph)
2932*da0073e9SAndroid Build Coastguard Worker            inp = torch.rand([1, 8, 8, 8])
2933*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(scripted_mod(inp), mod(inp))
2934*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(scripted_mod(inp), mod(inp))
2935*da0073e9SAndroid Build Coastguard Worker
2936*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
2937*da0073e9SAndroid Build Coastguard Worker        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
2938*da0073e9SAndroid Build Coastguard Worker    )
2939*da0073e9SAndroid Build Coastguard Worker    @skipIfNoTorchVision
2940*da0073e9SAndroid Build Coastguard Worker    def test_maxpool_mkldnn(self):
2941*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.float):
2942*da0073e9SAndroid Build Coastguard Worker            model = torchvision.models.resnet18()
2943*da0073e9SAndroid Build Coastguard Worker            sub_model = torch.nn.Sequential(
2944*da0073e9SAndroid Build Coastguard Worker                model.conv1, model.bn1, model.relu, model.maxpool
2945*da0073e9SAndroid Build Coastguard Worker            )
2946*da0073e9SAndroid Build Coastguard Worker            mod = torch.jit.freeze(torch.jit.script(sub_model.eval()))
2947*da0073e9SAndroid Build Coastguard Worker            N, C, H, W = 10, 3, 224, 224
2958*da0073e9SAndroid Build Coastguard Worker            inp = torch.randn(N, C, H, W)
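2958*da0073e9SAndroid Build Coastguard Worker            # max_pool should run as an MKLDNN op, with exactly one to_dense at the output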
2959*da0073e9SAndroid Build Coastguard Worker            self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
2960*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("max_pool").check("to_dense").run(mod.graph)
2961*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_count("to_dense", 1, exactly=True).run(mod.graph)
2962*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(mod(inp), sub_model(inp))
2963*da0073e9SAndroid Build Coastguard Worker
2964*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(torch.backends.mkldnn.is_available(), "Testing no mkldnn")
2965*da0073e9SAndroid Build Coastguard Worker    def test_conv_to_mkldnn_no_mkldnn(self):
2966*da0073e9SAndroid Build Coastguard Worker        # test that the pass raises no error when MKL-DNN is not available
2967*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.float):
2968*da0073e9SAndroid Build Coastguard Worker            mod = torch.jit.script(nn.Conv2d(3, 32, kernel_size=3, stride=2).eval())
2969*da0073e9SAndroid Build Coastguard Worker            frozen = torch.jit.freeze(mod)
2970*da0073e9SAndroid Build Coastguard Worker            self.run_pass("convert_frozen_ops_to_mkldnn", frozen.graph)
2971*da0073e9SAndroid Build Coastguard Worker            inp = torch.rand([4, 3, 4, 4])
2972*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(frozen(inp), mod(inp))
2973*da0073e9SAndroid Build Coastguard Worker
2974*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not (TEST_CUDNN or TEST_WITH_ROCM), "requires CUDNN")
2975*da0073e9SAndroid Build Coastguard Worker    def test_freeze_conv_relu_fusion(self):
2976*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.float):
2977*da0073e9SAndroid Build Coastguard Worker            conv_bias = [True, False]
2978*da0073e9SAndroid Build Coastguard Worker            conv_ops = [nn.Conv2d, nn.Conv3d]
2979*da0073e9SAndroid Build Coastguard Worker            use_add_z = [True, False]
2980*da0073e9SAndroid Build Coastguard Worker            use_tracing = [True, False]
2981*da0073e9SAndroid Build Coastguard Worker            for use_bias, conv, add_z, tracing in product(
2982*da0073e9SAndroid Build Coastguard Worker                conv_bias, conv_ops, use_add_z, use_tracing
2983*da0073e9SAndroid Build Coastguard Worker            ):
2984*da0073e9SAndroid Build Coastguard Worker
2985*da0073e9SAndroid Build Coastguard Worker                class Net(nn.Module):
2986*da0073e9SAndroid Build Coastguard Worker                    def __init__(self, in_channels, out_channels, **kwargs):
2987*da0073e9SAndroid Build Coastguard Worker                        super().__init__()
2988*da0073e9SAndroid Build Coastguard Worker                        self.conv = conv(
2989*da0073e9SAndroid Build Coastguard Worker                            in_channels, out_channels, bias=use_bias, **kwargs
2990*da0073e9SAndroid Build Coastguard Worker                        )
2991*da0073e9SAndroid Build Coastguard Worker                        self.relu = nn.ReLU(inplace=True)
2992*da0073e9SAndroid Build Coastguard Worker                        self.add_z = add_z
2993*da0073e9SAndroid Build Coastguard Worker
2994*da0073e9SAndroid Build Coastguard Worker                    def forward(self, x):
2995*da0073e9SAndroid Build Coastguard Worker                        z = self.conv(x)
2996*da0073e9SAndroid Build Coastguard Worker                        out = self.conv(x)
2997*da0073e9SAndroid Build Coastguard Worker                        if self.add_z:
2998*da0073e9SAndroid Build Coastguard Worker                            out += z
2999*da0073e9SAndroid Build Coastguard Worker                        out = self.relu(out)
3000*da0073e9SAndroid Build Coastguard Worker                        return out
3001*da0073e9SAndroid Build Coastguard Worker
3002*da0073e9SAndroid Build Coastguard Worker                mod_eager = Net(3, 6, kernel_size=3, stride=2).eval().cuda()
3003*da0073e9SAndroid Build Coastguard Worker
3004*da0073e9SAndroid Build Coastguard Worker                inps = [5, 3, 4, 4]
3005*da0073e9SAndroid Build Coastguard Worker                if conv == nn.Conv3d:
3006*da0073e9SAndroid Build Coastguard Worker                    inps.append(inps[-1])
3007*da0073e9SAndroid Build Coastguard Worker                inp = torch.rand(inps).cuda()
3008*da0073e9SAndroid Build Coastguard Worker
3009*da0073e9SAndroid Build Coastguard Worker                if tracing:
3010*da0073e9SAndroid Build Coastguard Worker                    scripted_mod = torch.jit.trace(mod_eager, (inp,))
3011*da0073e9SAndroid Build Coastguard Worker                else:
3012*da0073e9SAndroid Build Coastguard Worker                    scripted_mod = torch.jit.script(mod_eager)
3013*da0073e9SAndroid Build Coastguard Worker
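3013*da0073e9SAndroid Build Coastguard Worker                # optimize_for_inference should fuse conv + relu (and the optional add) into a single cuDNN/MIOpen kernel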
3014*da0073e9SAndroid Build Coastguard Worker                frozen_mod = torch.jit.optimize_for_inference(scripted_mod)
3015*da0073e9SAndroid Build Coastguard Worker                if TEST_WITH_ROCM:
3016*da0073e9SAndroid Build Coastguard Worker                    if add_z:
3017*da0073e9SAndroid Build Coastguard Worker                        FileCheck().check("aten::miopen_convolution_add_relu").run(
3018*da0073e9SAndroid Build Coastguard Worker                            frozen_mod.graph
3019*da0073e9SAndroid Build Coastguard Worker                        )
3020*da0073e9SAndroid Build Coastguard Worker                    else:
3021*da0073e9SAndroid Build Coastguard Worker                        FileCheck().check("aten::miopen_convolution_relu").run(
3022*da0073e9SAndroid Build Coastguard Worker                            frozen_mod.graph
3023*da0073e9SAndroid Build Coastguard Worker                        )
3024*da0073e9SAndroid Build Coastguard Worker                else:
3025*da0073e9SAndroid Build Coastguard Worker                    if add_z:
3026*da0073e9SAndroid Build Coastguard Worker                        FileCheck().check("aten::cudnn_convolution_add_relu").run(
3027*da0073e9SAndroid Build Coastguard Worker                            frozen_mod.graph
3028*da0073e9SAndroid Build Coastguard Worker                        )
3029*da0073e9SAndroid Build Coastguard Worker                    else:
3030*da0073e9SAndroid Build Coastguard Worker                        FileCheck().check("aten::cudnn_convolution_relu").run(
3031*da0073e9SAndroid Build Coastguard Worker                            frozen_mod.graph
3032*da0073e9SAndroid Build Coastguard Worker                        )
3033*da0073e9SAndroid Build Coastguard Worker
3034*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(mod_eager(inp), frozen_mod(inp))
3035*da0073e9SAndroid Build Coastguard Worker
3036*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not (TEST_CUDNN or TEST_WITH_ROCM), "requires CUDNN")
3037*da0073e9SAndroid Build Coastguard Worker    def test_freeze_conv_relu_fusion_not_forward(self):
3038*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.float):
3039*da0073e9SAndroid Build Coastguard Worker
3040*da0073e9SAndroid Build Coastguard Worker            class Net(nn.Module):
3041*da0073e9SAndroid Build Coastguard Worker                def __init__(self, in_channels, out_channels, **kwargs):
3042*da0073e9SAndroid Build Coastguard Worker                    super().__init__()
3043*da0073e9SAndroid Build Coastguard Worker                    self.conv = nn.Conv2d(
3044*da0073e9SAndroid Build Coastguard Worker                        in_channels, out_channels, bias=False, **kwargs
3045*da0073e9SAndroid Build Coastguard Worker                    )
3046*da0073e9SAndroid Build Coastguard Worker                    self.relu = nn.ReLU(inplace=True)
3047*da0073e9SAndroid Build Coastguard Worker
3048*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):
3049*da0073e9SAndroid Build Coastguard Worker                    z = self.conv(x)
3050*da0073e9SAndroid Build Coastguard Worker                    out = self.conv(x)
3051*da0073e9SAndroid Build Coastguard Worker                    out = self.relu(out)
3052*da0073e9SAndroid Build Coastguard Worker                    return out
3053*da0073e9SAndroid Build Coastguard Worker
3054*da0073e9SAndroid Build Coastguard Worker                @torch.jit.export
3055*da0073e9SAndroid Build Coastguard Worker                def make_prediction(self, x):
3056*da0073e9SAndroid Build Coastguard Worker                    return self.forward(x)
3057*da0073e9SAndroid Build Coastguard Worker
3058*da0073e9SAndroid Build Coastguard Worker            mod_eager = Net(3, 6, kernel_size=3, stride=2).eval().cuda()
3059*da0073e9SAndroid Build Coastguard Worker
3060*da0073e9SAndroid Build Coastguard Worker            inps = [5, 3, 4, 4]
3061*da0073e9SAndroid Build Coastguard Worker            inp = torch.rand(inps).cuda()
3062*da0073e9SAndroid Build Coastguard Worker
3063*da0073e9SAndroid Build Coastguard Worker            scripted_mod = torch.jit.script(mod_eager)
3064*da0073e9SAndroid Build Coastguard Worker
3065*da0073e9SAndroid Build Coastguard Worker            frozen_mod = torch.jit.freeze(
3066*da0073e9SAndroid Build Coastguard Worker                scripted_mod, preserved_attrs=["make_prediction"]
3067*da0073e9SAndroid Build Coastguard Worker            )
3068*da0073e9SAndroid Build Coastguard Worker            optimized_mod = torch.jit.optimize_for_inference(
3069*da0073e9SAndroid Build Coastguard Worker                frozen_mod, other_methods=["make_prediction"]
3070*da0073e9SAndroid Build Coastguard Worker            )
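3070*da0073e9SAndroid Build Coastguard Worker            # the fusion must apply to the preserved make_prediction method as well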
3071*da0073e9SAndroid Build Coastguard Worker            if TEST_WITH_ROCM:
3072*da0073e9SAndroid Build Coastguard Worker                FileCheck().check("aten::miopen_convolution_relu").run(
3073*da0073e9SAndroid Build Coastguard Worker                    optimized_mod.make_prediction.graph
3074*da0073e9SAndroid Build Coastguard Worker                )
3075*da0073e9SAndroid Build Coastguard Worker            else:
3076*da0073e9SAndroid Build Coastguard Worker                FileCheck().check("aten::cudnn_convolution_relu").run(
3077*da0073e9SAndroid Build Coastguard Worker                    optimized_mod.make_prediction.graph
3078*da0073e9SAndroid Build Coastguard Worker                )
3079*da0073e9SAndroid Build Coastguard Worker
3080*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
3081*da0073e9SAndroid Build Coastguard Worker                mod_eager.make_prediction(inp), optimized_mod.make_prediction(inp)
3082*da0073e9SAndroid Build Coastguard Worker            )
3083*da0073e9SAndroid Build Coastguard Worker
3084*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
3085*da0073e9SAndroid Build Coastguard Worker        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
3086*da0073e9SAndroid Build Coastguard Worker    )
3087*da0073e9SAndroid Build Coastguard Worker    def test_numel_less_than_size_with_padding(self):
3088*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.float):
3089*da0073e9SAndroid Build Coastguard Worker
3090*da0073e9SAndroid Build Coastguard Worker            class MyModule(nn.Module):
3091*da0073e9SAndroid Build Coastguard Worker                def __init__(self) -> None:
3092*da0073e9SAndroid Build Coastguard Worker                    super().__init__()
3093*da0073e9SAndroid Build Coastguard Worker                    self.conv1 = nn.Conv2d(
3094*da0073e9SAndroid Build Coastguard Worker                        1,
3095*da0073e9SAndroid Build Coastguard Worker                        2,
3096*da0073e9SAndroid Build Coastguard Worker                        kernel_size=(2, 4),
3097*da0073e9SAndroid Build Coastguard Worker                        stride=2,
3098*da0073e9SAndroid Build Coastguard Worker                        padding=2,
3099*da0073e9SAndroid Build Coastguard Worker                        dilation=(2, 1),
3100*da0073e9SAndroid Build Coastguard Worker                    )
3101*da0073e9SAndroid Build Coastguard Worker
3102*da0073e9SAndroid Build Coastguard Worker                def forward(self, i0):
3103*da0073e9SAndroid Build Coastguard Worker                    x = self.conv1(i0)
3104*da0073e9SAndroid Build Coastguard Worker                    o0 = torch.max(x, i0)
3105*da0073e9SAndroid Build Coastguard Worker                    o1 = torch.clip(x, -1.5, 1.5)
3106*da0073e9SAndroid Build Coastguard Worker                    return o0, o1
3107*da0073e9SAndroid Build Coastguard Worker
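3107*da0073e9SAndroid Build Coastguard Worker            # tiny input whose numel is smaller than its padded size; the optimized module must still match eager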
3108*da0073e9SAndroid Build Coastguard Worker            i0 = torch.zeros((1, 1, 1, 2), dtype=torch.float32)
3109*da0073e9SAndroid Build Coastguard Worker            mod = MyModule()
3110*da0073e9SAndroid Build Coastguard Worker            out = mod(i0)
3111*da0073e9SAndroid Build Coastguard Worker
3112*da0073e9SAndroid Build Coastguard Worker            exported = torch.jit.trace(mod, [i0])
3113*da0073e9SAndroid Build Coastguard Worker            exported = torch.jit.optimize_for_inference(exported)
3114*da0073e9SAndroid Build Coastguard Worker
3115*da0073e9SAndroid Build Coastguard Worker            eout = exported(i0)
3116*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(all(torch.allclose(x, y) for x, y in zip(out, eout)))
3117*da0073e9SAndroid Build Coastguard Worker
3118*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
3119*da0073e9SAndroid Build Coastguard Worker        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
3120*da0073e9SAndroid Build Coastguard Worker    )
3121*da0073e9SAndroid Build Coastguard Worker    def test_incompatible_perf_formats(self):
3122*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.float):
3123*da0073e9SAndroid Build Coastguard Worker
3124*da0073e9SAndroid Build Coastguard Worker            class Mod(nn.Module):
3125*da0073e9SAndroid Build Coastguard Worker                def __init__(self) -> None:
3126*da0073e9SAndroid Build Coastguard Worker                    super().__init__()
3127*da0073e9SAndroid Build Coastguard Worker                    self.conv = torch.nn.Conv2d(3, 64, 3, 2)
3128*da0073e9SAndroid Build Coastguard Worker                    self.max_pool = torch.nn.MaxPool2d(111, 111)
3129*da0073e9SAndroid Build Coastguard Worker
3130*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):
3131*da0073e9SAndroid Build Coastguard Worker                    a = self.conv(x)
3132*da0073e9SAndroid Build Coastguard Worker                    b = self.max_pool(a)
3133*da0073e9SAndroid Build Coastguard Worker                    return a + b
3134*da0073e9SAndroid Build Coastguard Worker
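3134*da0073e9SAndroid Build Coastguard Worker            # `a` feeds both max_pool and the add, which may prefer incompatible MKLDNN formats; results must still match eager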
3135*da0073e9SAndroid Build Coastguard Worker            model = Mod()
3136*da0073e9SAndroid Build Coastguard Worker            model.eval()
3137*da0073e9SAndroid Build Coastguard Worker            mod = torch.jit.freeze(torch.jit.script(model))
3138*da0073e9SAndroid Build Coastguard Worker            N, C, H, W = 10, 3, 224, 224
3149*da0073e9SAndroid Build Coastguard Worker            inp = torch.randn(N, C, H, W)
3150*da0073e9SAndroid Build Coastguard Worker            self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
3151*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(model(inp), mod(inp))
3152*da0073e9SAndroid Build Coastguard Worker
3153*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
3154*da0073e9SAndroid Build Coastguard Worker        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
3155*da0073e9SAndroid Build Coastguard Worker    )
3156*da0073e9SAndroid Build Coastguard Worker    def test_pool2d_batchnorm(self):
3157*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.float):
3158*da0073e9SAndroid Build Coastguard Worker            pooling_layers = [
3159*da0073e9SAndroid Build Coastguard Worker                torch.nn.AdaptiveAvgPool2d(4),
3160*da0073e9SAndroid Build Coastguard Worker                # torch.nn.AdaptiveMaxPool2d(4), # returns a tuple
3161*da0073e9SAndroid Build Coastguard Worker                torch.nn.MaxPool2d(4),
3162*da0073e9SAndroid Build Coastguard Worker                torch.nn.AvgPool2d(4),
3163*da0073e9SAndroid Build Coastguard Worker                torch.nn.BatchNorm2d(64).eval(),
3164*da0073e9SAndroid Build Coastguard Worker            ]
3165*da0073e9SAndroid Build Coastguard Worker
3166*da0073e9SAndroid Build Coastguard Worker            for pl in pooling_layers:
3167*da0073e9SAndroid Build Coastguard Worker                sub_model = torch.nn.Sequential(
3168*da0073e9SAndroid Build Coastguard Worker                    torch.nn.Conv2d(3, 64, 2, 2),
3169*da0073e9SAndroid Build Coastguard Worker                    torch.nn.ReLU(),
3170*da0073e9SAndroid Build Coastguard Worker                    pl,
3171*da0073e9SAndroid Build Coastguard Worker                    torch.nn.Hardswish(),
3172*da0073e9SAndroid Build Coastguard Worker                )
3173*da0073e9SAndroid Build Coastguard Worker                sub_model.eval()
3174*da0073e9SAndroid Build Coastguard Worker                mod = torch.jit.freeze(torch.jit.script(sub_model))
3175*da0073e9SAndroid Build Coastguard Worker                N, C, H, W = 10, 3, 224, 224
3186*da0073e9SAndroid Build Coastguard Worker                inp = torch.randn(N, C, H, W)
3187*da0073e9SAndroid Build Coastguard Worker                # these two passes are needed to remove
3188*da0073e9SAndroid Build Coastguard Worker                # a size check in BatchNorm2d
3189*da0073e9SAndroid Build Coastguard Worker                removeExceptions(mod.graph)
3190*da0073e9SAndroid Build Coastguard Worker                self.run_pass("dce", mod.graph)
3191*da0073e9SAndroid Build Coastguard Worker                self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
3192*da0073e9SAndroid Build Coastguard Worker                FileCheck().check("aten::to_dense").check_next("return").run(mod.graph)
3193*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(sub_model(inp), mod(inp))
3194*da0073e9SAndroid Build Coastguard Worker
3195*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
3196*da0073e9SAndroid Build Coastguard Worker        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
3197*da0073e9SAndroid Build Coastguard Worker    )
3198*da0073e9SAndroid Build Coastguard Worker    def test_pool3d_batchnorm(self):
3199*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.float):
3200*da0073e9SAndroid Build Coastguard Worker            pooling_layers = [
3201*da0073e9SAndroid Build Coastguard Worker                torch.nn.MaxPool3d(4),
3202*da0073e9SAndroid Build Coastguard Worker                # torch.nn.AdaptiveAvgPool3d(4), # no ideep bindings
3203*da0073e9SAndroid Build Coastguard Worker                # torch.nn.AdaptiveMaxPool3d(4), # returns a tuple
3204*da0073e9SAndroid Build Coastguard Worker                torch.nn.AvgPool3d(4),
3205*da0073e9SAndroid Build Coastguard Worker                torch.nn.BatchNorm3d(64).eval(),
3206*da0073e9SAndroid Build Coastguard Worker            ]
3207*da0073e9SAndroid Build Coastguard Worker
3208*da0073e9SAndroid Build Coastguard Worker            for pl in pooling_layers:
3209*da0073e9SAndroid Build Coastguard Worker                sub_model = torch.nn.Sequential(
3210*da0073e9SAndroid Build Coastguard Worker                    torch.nn.Conv3d(3, 64, 2, 2),
3211*da0073e9SAndroid Build Coastguard Worker                    torch.nn.ReLU(),
3212*da0073e9SAndroid Build Coastguard Worker                    pl,
3213*da0073e9SAndroid Build Coastguard Worker                    torch.nn.Hardswish(),
3214*da0073e9SAndroid Build Coastguard Worker                )
3215*da0073e9SAndroid Build Coastguard Worker                sub_model.eval()
3216*da0073e9SAndroid Build Coastguard Worker                mod = torch.jit.freeze(torch.jit.script(sub_model))
3217*da0073e9SAndroid Build Coastguard Worker                N, C, H, W, D = 10, 3, 64, 64, 64
3218*da0073e9SAndroid Build Coastguard Worker                inp = torch.randn(N, C, D, H, W)
3219*da0073e9SAndroid Build Coastguard Worker                # these two passes are needed to remove
3220*da0073e9SAndroid Build Coastguard Worker                # a size check in BatchNorm3d
3221*da0073e9SAndroid Build Coastguard Worker                removeExceptions(mod.graph)
3222*da0073e9SAndroid Build Coastguard Worker                self.run_pass("dce", mod.graph)
3223*da0073e9SAndroid Build Coastguard Worker                self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
3224*da0073e9SAndroid Build Coastguard Worker                FileCheck().check("aten::to_dense").check_next("return").run(mod.graph)
3225*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(sub_model(inp), mod(inp))
3226*da0073e9SAndroid Build Coastguard Worker
3227*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
3228*da0073e9SAndroid Build Coastguard Worker        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
3229*da0073e9SAndroid Build Coastguard Worker    )
3230*da0073e9SAndroid Build Coastguard Worker    @skipIfNoTorchVision
3231*da0073e9SAndroid Build Coastguard Worker    def test_conv_hardswish(self):
3232*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.float):
3233*da0073e9SAndroid Build Coastguard Worker
3234*da0073e9SAndroid Build Coastguard Worker            class Clamp(torch.nn.Module):
3235*da0073e9SAndroid Build Coastguard Worker                def __init__(self, min_val, max_val, **kwargs):
3236*da0073e9SAndroid Build Coastguard Worker                    super().__init__()
3237*da0073e9SAndroid Build Coastguard Worker                    self.min_val = min_val
3238*da0073e9SAndroid Build Coastguard Worker                    self.max_val = max_val
3239*da0073e9SAndroid Build Coastguard Worker
3240*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):
3241*da0073e9SAndroid Build Coastguard Worker                    return torch.clamp(x, self.min_val, self.max_val)
3242*da0073e9SAndroid Build Coastguard Worker
3243*da0073e9SAndroid Build Coastguard Worker            N, C, H, W = 10, 3, 224, 224
3254*da0073e9SAndroid Build Coastguard Worker            activations = [
3255*da0073e9SAndroid Build Coastguard Worker                torch.nn.Hardswish(),
3256*da0073e9SAndroid Build Coastguard Worker                torch.nn.Hardsigmoid(),
3257*da0073e9SAndroid Build Coastguard Worker                torch.nn.ReLU6(),
3258*da0073e9SAndroid Build Coastguard Worker                torch.nn.Tanh(),
3259*da0073e9SAndroid Build Coastguard Worker                torch.nn.Hardtanh(0.0, 6.0),
3260*da0073e9SAndroid Build Coastguard Worker                torch.nn.Hardtanh(1.0, 100.0),
3261*da0073e9SAndroid Build Coastguard Worker                torch.nn.Hardtanh(-100.0, -1.0),
3262*da0073e9SAndroid Build Coastguard Worker                torch.nn.GELU(),
3263*da0073e9SAndroid Build Coastguard Worker                Clamp(-100.0, -1.0),
3264*da0073e9SAndroid Build Coastguard Worker                Clamp(1.0, 100.0),
3265*da0073e9SAndroid Build Coastguard Worker                Clamp(0.0, 6.0),
3266*da0073e9SAndroid Build Coastguard Worker                Clamp(-1.0, 0.0),
3267*da0073e9SAndroid Build Coastguard Worker            ]
3268*da0073e9SAndroid Build Coastguard Worker
3269*da0073e9SAndroid Build Coastguard Worker            model = torchvision.models.resnet18()
3270*da0073e9SAndroid Build Coastguard Worker            for activation in activations:
3271*da0073e9SAndroid Build Coastguard Worker                sub_model = torch.nn.Sequential(model.conv1, activation)
3272*da0073e9SAndroid Build Coastguard Worker                sub_model.eval()
3273*da0073e9SAndroid Build Coastguard Worker                mod = torch.jit.freeze(torch.jit.script(sub_model))
3274*da0073e9SAndroid Build Coastguard Worker                inp = torch.randn(N, C, H, W)
3275*da0073e9SAndroid Build Coastguard Worker                self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
3276*da0073e9SAndroid Build Coastguard Worker                FileCheck().check_count("aten::to_dense", 1, exactly=True).run(
3277*da0073e9SAndroid Build Coastguard Worker                    mod.graph
3278*da0073e9SAndroid Build Coastguard Worker                )
3279*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(sub_model(inp), mod(inp))
3280*da0073e9SAndroid Build Coastguard Worker
3281*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
3282*da0073e9SAndroid Build Coastguard Worker        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
3283*da0073e9SAndroid Build Coastguard Worker    )
3284*da0073e9SAndroid Build Coastguard Worker    def test_hardswish_hardsigmoid(self):
3285*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.float):
3286*da0073e9SAndroid Build Coastguard Worker            op_map = {
3287*da0073e9SAndroid Build Coastguard Worker                "prim::MKLDNNHardSwish": F.hardswish,
3288*da0073e9SAndroid Build Coastguard Worker                "prim::MKLDNNHardSigmoid": F.hardsigmoid,
3289*da0073e9SAndroid Build Coastguard Worker            }
3290*da0073e9SAndroid Build Coastguard Worker
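3290*da0073e9SAndroid Build Coastguard Worker            # input sizes cover an empty tensor, small 1-D tensors, and a 4-D tensor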
3291*da0073e9SAndroid Build Coastguard Worker            input_sizes = ([0], [1], [3], [1, 3, 8, 8])
3292*da0073e9SAndroid Build Coastguard Worker            for mkldnn_opname, aten_op in op_map.items():
3293*da0073e9SAndroid Build Coastguard Worker                for size in input_sizes:
3294*da0073e9SAndroid Build Coastguard Worker                    for inplace in (True, False):
3295*da0073e9SAndroid Build Coastguard Worker                        inplace_str = "_" if inplace else ""
3296*da0073e9SAndroid Build Coastguard Worker                        inplace_tgt = "%34" if inplace else "%35"
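3296*da0073e9SAndroid Build Coastguard Worker                        # hand-written IR exercises the MKLDNN op directly, without freezing or conversion passes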
3297*da0073e9SAndroid Build Coastguard Worker                        graph_str = f"""graph(%input.1 : Tensor):
3298*da0073e9SAndroid Build Coastguard Worker                            %33 : None = prim::Constant()
3299*da0073e9SAndroid Build Coastguard Worker                            %34 : Tensor = aten::to_mkldnn(%input.1, %33)
3300*da0073e9SAndroid Build Coastguard Worker                            %35 : Tensor = {mkldnn_opname}{inplace_str}(%34)
3301*da0073e9SAndroid Build Coastguard Worker                            return ({inplace_tgt})
3302*da0073e9SAndroid Build Coastguard Worker                        """
3303*da0073e9SAndroid Build Coastguard Worker                        g = torch._C.parse_ir(graph_str)
3304*da0073e9SAndroid Build Coastguard Worker                        m = self.createFunctionFromGraph(g)
3305*da0073e9SAndroid Build Coastguard Worker                        x = torch.rand(size)
3306*da0073e9SAndroid Build Coastguard Worker                        # `inplace=False` is intentional: otherwise we would modify the input,
3307*da0073e9SAndroid Build Coastguard Worker                        # and we aren't testing the aten impls anyway
3308*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(aten_op(x, inplace=False), m(x).to_dense())
3309*da0073e9SAndroid Build Coastguard Worker
3310*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
3311*da0073e9SAndroid Build Coastguard Worker        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
3312*da0073e9SAndroid Build Coastguard Worker    )
3313*da0073e9SAndroid Build Coastguard Worker    def test_scalar_mul(self):
3314*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.float):
3315*da0073e9SAndroid Build Coastguard Worker
3316*da0073e9SAndroid Build Coastguard Worker            class Mod(nn.Module):
3317*da0073e9SAndroid Build Coastguard Worker                def __init__(self) -> None:
3318*da0073e9SAndroid Build Coastguard Worker                    super().__init__()
3319*da0073e9SAndroid Build Coastguard Worker                    self.mod = nn.Conv2d(8, 8, 1, padding=1)
3320*da0073e9SAndroid Build Coastguard Worker
3321*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):
3322*da0073e9SAndroid Build Coastguard Worker                    a1 = self.mod(x) * 4
3323*da0073e9SAndroid Build Coastguard Worker                    return a1 * 4 + a1 * 5.0
3324*da0073e9SAndroid Build Coastguard Worker
3325*da0073e9SAndroid Build Coastguard Worker            mod = Mod().eval()
3326*da0073e9SAndroid Build Coastguard Worker            scripted = torch.jit.freeze(torch.jit.script(mod))
3327*da0073e9SAndroid Build Coastguard Worker            optimized = torch.jit.optimize_for_inference(scripted)
3328*da0073e9SAndroid Build Coastguard Worker            inp = torch.rand([1, 8, 8, 8])
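3328*da0073e9SAndroid Build Coastguard Worker            # scalar multiplications should lower to prim::ScalarMul ops on MKLDNN tensors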
3329*da0073e9SAndroid Build Coastguard Worker            # a1 can't be inplaced at its first use, but can be at its second
3330*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("ScalarMul(").check("ScalarMul_").run(optimized.graph)
3331*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(optimized(inp), mod(inp))
3332*da0073e9SAndroid Build Coastguard Worker
3333*da0073e9SAndroid Build Coastguard Worker    def test_remove_detach(self):
3334*da0073e9SAndroid Build Coastguard Worker        class Mod(nn.Module):
3335*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3336*da0073e9SAndroid Build Coastguard Worker                y = x.detach()
3337*da0073e9SAndroid Build Coastguard Worker                return y * y
3338*da0073e9SAndroid Build Coastguard Worker
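3338*da0073e9SAndroid Build Coastguard Worker        # detach is a no-op in an inference-only frozen graph, so it should be removed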
3339*da0073e9SAndroid Build Coastguard Worker        mod = Mod().eval()
3340*da0073e9SAndroid Build Coastguard Worker        frozen_mod = torch.jit.freeze(torch.jit.script(mod))
3341*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn((2, 2))
3342*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("aten::detach").run(frozen_mod.graph)
3343*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(frozen_mod(inp), mod(inp))
3344*da0073e9SAndroid Build Coastguard Worker
3345*da0073e9SAndroid Build Coastguard Worker    def test_remove_detach_not_applied(self):
3346*da0073e9SAndroid Build Coastguard Worker        class Mod(nn.Module):
3347*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3348*da0073e9SAndroid Build Coastguard Worker                y = x.detach()
3349*da0073e9SAndroid Build Coastguard Worker                return x is y
3350*da0073e9SAndroid Build Coastguard Worker
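3350*da0073e9SAndroid Build Coastguard Worker        # removing detach here would change the observable result of `x is y`, so it must be kept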
3351*da0073e9SAndroid Build Coastguard Worker        mod = Mod().eval()
3352*da0073e9SAndroid Build Coastguard Worker        frozen_mod = torch.jit.freeze(torch.jit.script(mod))
3353*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn((2, 2))
3354*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::detach").run(frozen_mod.graph)
3355*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(frozen_mod(inp), mod(inp))
3356*da0073e9SAndroid Build Coastguard Worker
3357*da0073e9SAndroid Build Coastguard Worker
3358*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo("somehow causing hanging during python shutdown")
3359*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled")
3360*da0073e9SAndroid Build Coastguard Workerclass TestMKLDNNReinplacing(JitTestCase):
3361*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
3362*da0073e9SAndroid Build Coastguard Worker        super().setUp()
3363*da0073e9SAndroid Build Coastguard Worker        self.default_dtype = torch.get_default_dtype()
3364*da0073e9SAndroid Build Coastguard Worker        torch.set_default_dtype(torch.float)
3365*da0073e9SAndroid Build Coastguard Worker
3366*da0073e9SAndroid Build Coastguard Worker    def tearDown(self):
3367*da0073e9SAndroid Build Coastguard Worker        super().tearDown()
3368*da0073e9SAndroid Build Coastguard Worker        torch.set_default_dtype(self.default_dtype)
3369*da0073e9SAndroid Build Coastguard Worker
3370*da0073e9SAndroid Build Coastguard Worker    def getConv(self):
3371*da0073e9SAndroid Build Coastguard Worker        return nn.Conv2d(3, 32, kernel_size=3, stride=2).eval()
3372*da0073e9SAndroid Build Coastguard Worker
3373*da0073e9SAndroid Build Coastguard Worker    def getInput(self):
3374*da0073e9SAndroid Build Coastguard Worker        return torch.rand([4, 3, 4, 4])
3375*da0073e9SAndroid Build Coastguard Worker
3376*da0073e9SAndroid Build Coastguard Worker    def freezeAndConvert(self, mod):
3377*da0073e9SAndroid Build Coastguard Worker        mod = torch.jit.freeze(torch.jit.script(mod.eval()))
3378*da0073e9SAndroid Build Coastguard Worker        self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
3379*da0073e9SAndroid Build Coastguard Worker        return mod
3380*da0073e9SAndroid Build Coastguard Worker
3381*da0073e9SAndroid Build Coastguard Worker    def checkResults(self, mod1, mod2):
3382*da0073e9SAndroid Build Coastguard Worker        inp = self.getInput()
3383*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(mod1(inp), mod2(inp))
3384*da0073e9SAndroid Build Coastguard Worker
3385*da0073e9SAndroid Build Coastguard Worker    def test_successful(self):
3386*da0073e9SAndroid Build Coastguard Worker        # simple conv -> hardswish -> relu; both post-conv ops should be inplaced
3387*da0073e9SAndroid Build Coastguard Worker
3388*da0073e9SAndroid Build Coastguard Worker        mod_eager = nn.Sequential(self.getConv(), nn.Hardswish(), nn.ReLU())
3389*da0073e9SAndroid Build Coastguard Worker        mod = self.freezeAndConvert(mod_eager)
3390*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("mkldnn_convolution").check_next(
3391*da0073e9SAndroid Build Coastguard Worker            "prim::MKLDNNHardSwish_"
3392*da0073e9SAndroid Build Coastguard Worker        ).check_next("aten::relu_").run(mod.graph)
3393*da0073e9SAndroid Build Coastguard Worker        self.checkResults(mod_eager, mod)
3394*da0073e9SAndroid Build Coastguard Worker
3395*da0073e9SAndroid Build Coastguard Worker    def test_merge_liveness(self):
3396*da0073e9SAndroid Build Coastguard Worker        class Mod(nn.Module):
3397*da0073e9SAndroid Build Coastguard Worker            def __init__(self, tensor):
3398*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3399*da0073e9SAndroid Build Coastguard Worker                self.tensor = tensor
3400*da0073e9SAndroid Build Coastguard Worker
3401*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3402*da0073e9SAndroid Build Coastguard Worker                # this mul can be inplaced since x is dead after this use
3403*da0073e9SAndroid Build Coastguard Worker                temporary = x * self.tensor
3404*da0073e9SAndroid Build Coastguard Worker                # temporary's lifespan extends to the return node,
3405*da0073e9SAndroid Build Coastguard Worker                # so the add cannot be inplaced
3406*da0073e9SAndroid Build Coastguard Worker                return temporary + temporary, temporary
3407*da0073e9SAndroid Build Coastguard Worker
3408*da0073e9SAndroid Build Coastguard Worker        mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1])))
3409*da0073e9SAndroid Build Coastguard Worker        mod = self.freezeAndConvert(mod_eager)
3410*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::mul_").check_not("aten::add_").run(mod.graph)
3411*da0073e9SAndroid Build Coastguard Worker        self.checkResults(mod_eager, mod)
3412*da0073e9SAndroid Build Coastguard Worker
3413*da0073e9SAndroid Build Coastguard Worker    def test_always_alive_values(self):
3414*da0073e9SAndroid Build Coastguard Worker        class Mod(nn.Module):
3415*da0073e9SAndroid Build Coastguard Worker            def __init__(self, tensor):
3416*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3417*da0073e9SAndroid Build Coastguard Worker                self.tensor = tensor
3418*da0073e9SAndroid Build Coastguard Worker
3419*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3420*da0073e9SAndroid Build Coastguard Worker                # x can't be inplaced because it's a return value;
3421*da0073e9SAndroid Build Coastguard Worker                # check that the inplacing pass doesn't try to inplace
3422*da0073e9SAndroid Build Coastguard Worker                # self.tensor either, because it is always alive
3423*da0073e9SAndroid Build Coastguard Worker                return x * self.tensor, x
3424*da0073e9SAndroid Build Coastguard Worker
3425*da0073e9SAndroid Build Coastguard Worker        mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1])))
3426*da0073e9SAndroid Build Coastguard Worker        mod = self.freezeAndConvert(mod_eager)
3427*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("aten::mul_").run(mod.graph)
3428*da0073e9SAndroid Build Coastguard Worker        self.checkResults(mod_eager, mod)
3429*da0073e9SAndroid Build Coastguard Worker
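3429*da0073e9SAndroid Build Coastguard Worker        # second scenario: a graph input is reused, and graph inputs are always alive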
3430*da0073e9SAndroid Build Coastguard Worker        conv = self.getConv()
3431*da0073e9SAndroid Build Coastguard Worker
3432*da0073e9SAndroid Build Coastguard Worker        class Mod(nn.Module):
3433*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
3434*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3435*da0073e9SAndroid Build Coastguard Worker                self.tensor = torch.rand([4, 32, 1, 1])
3436*da0073e9SAndroid Build Coastguard Worker                self.conv = conv
3437*da0073e9SAndroid Build Coastguard Worker
3438*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3439*da0073e9SAndroid Build Coastguard Worker                # the shapes don't add up here; we are just testing a particular pattern
3440*da0073e9SAndroid Build Coastguard Worker                conv_output = self.conv(x)
3441*da0073e9SAndroid Build Coastguard Worker                return conv_output, self.conv(torch.add(x, x))
3442*da0073e9SAndroid Build Coastguard Worker
3443*da0073e9SAndroid Build Coastguard Worker        mod = self.freezeAndConvert(Mod())
3444*da0073e9SAndroid Build Coastguard Worker        # x is an input to the graph, and so it should not be inplaced
3445*da0073e9SAndroid Build Coastguard Worker        # in the torch.add(x, x) call
3446*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("aten::add_").run(mod.graph)
3447*da0073e9SAndroid Build Coastguard Worker
3448*da0073e9SAndroid Build Coastguard Worker    def test_switch_inputs_to_inplace(self):
3449*da0073e9SAndroid Build Coastguard Worker        class Mod(nn.Module):
3450*da0073e9SAndroid Build Coastguard Worker            def __init__(self, tensor):
3451*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3452*da0073e9SAndroid Build Coastguard Worker                self.tensor = tensor
3453*da0073e9SAndroid Build Coastguard Worker
3454*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3455*da0073e9SAndroid Build Coastguard Worker                # self.tensor cannot be inplaced, but x can; because add is
3456*da0073e9SAndroid Build Coastguard Worker                # commutative, the inputs to add_ can be swapped
3457*da0073e9SAndroid Build Coastguard Worker                return self.tensor + x
3458*da0073e9SAndroid Build Coastguard Worker
3459*da0073e9SAndroid Build Coastguard Worker        mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1])))
3460*da0073e9SAndroid Build Coastguard Worker        mod = self.freezeAndConvert(mod_eager)
3461*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::add_").run(mod.graph)
3462*da0073e9SAndroid Build Coastguard Worker        self.checkResults(mod_eager, mod)
3463