# Owner(s): ["oncall: jit"]

import io
import unittest
from itertools import product
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit._recursive import wrap_cpp_module
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import (
    set_default_dtype,
    skipCUDAMemoryLeakCheckIf,
    skipIfTorchDynamo,
    TEST_WITH_ROCM,
)
from torch.testing._internal.jit_utils import JitTestCase
from torch.utils import mkldnn as mkldnn_utils


try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )

TEST_ROCM = torch.cuda.is_available() and torch.version.hip is not None

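# Helper: strip all prim::RaiseException nodes from a graph, leaving the rest
# of the graph intact.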
def removeExceptions(graph):
    for n in graph.findAllNodes("prim::RaiseException"):
        n.destroy()

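# The tests below exercise TorchScript module freezing. A minimal usage
# sketch (assuming an eval-mode ScriptModule), not a definitive recipe:
#
#     m = torch.jit.script(MyModule()).eval()
#     mf = torch.jit.freeze(m)                # public API
#     mf_c = torch._C._freeze_module(m._c)    # internal entry point used below
#
# Freezing clones the module, inlines submodules, and folds attributes that
# are provably never mutated into constants in the forward graph.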
class TestFreezing(JitTestCase):
    def test_freeze_module(self):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 1  # folded
                self.b = 1.2  # folded
                self.c = "hello"  # folded
                self.c2 = "hi\xA1"  # folded
                self.d = [1, 1]  # folded
                self.e = [1.0, 1.1]  # folded
                self.f = ["hello", "world"]  # folded
                self.f2 = [(1, "Over \u0e55\u0e57 57")]
                self.g = (
                    [1, 2],
                    3.2,
                    "4.4",
                    torch.tensor([5.5], requires_grad=True),
                )  # folded
                self.h = {"layer": [torch.tensor([7.7], requires_grad=True)]}
                self.h2 = {"layer\xB1": [torch.tensor([8.8], requires_grad=True)]}
                self.t = torch.tensor([1.2, 2.4], requires_grad=True)  # folded
                self.ts = [
                    torch.tensor([1.0, 2.0], requires_grad=True),
                    torch.tensor([3.0, 4.0], requires_grad=True),
                ]  # folded
                self.tt = [[torch.tensor([3.3, 2.3], requires_grad=True), None]]

            def forward(self, x):
                return (
                    str(self.a)
                    + str(self.b)
                    + self.c
                    + self.c2
                    + str(self.d)
                    + str(self.e)
                    + str(self.f)
                    + str(self.f2)
                    + str(self.g)
                    + str(self.h)
                    + str(self.h2)
                    + str(self.t)
                    + str(self.ts)
                    + str(self.tt)
                )

        m = torch.jit.script(M())
        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)
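        # torch._C._freeze_module is the internal entry point used throughout
        # this file; torch.jit.freeze wraps it for the public API.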
        m._c = torch._C._freeze_module(m._c)
        buffer = io.BytesIO()
        torch.jit.save(m._c, buffer)
        buffer.seek(0)
        m2 = torch.jit.load(buffer)
        # Check if frozen module looks as below (all attributes are folded):
        # module m {
        #   attributes {
        #   }
        #   ...
        # }
111        self.assertFalse(m2._c.hasattr("a"))
112        self.assertFalse(m2._c.hasattr("b"))
113        self.assertFalse(m2._c.hasattr("c"))
114        self.assertFalse(m2._c.hasattr("c2"))
115        self.assertFalse(m2._c.hasattr("d"))
116        self.assertFalse(m2._c.hasattr("e"))
117        self.assertFalse(m2._c.hasattr("f"))
118        self.assertFalse(m2._c.hasattr("f2"))
119        self.assertFalse(m2._c.hasattr("g"))
120        self.assertFalse(m2._c.hasattr("h"))
121        self.assertFalse(m2._c.hasattr("h2"))
122        self.assertFalse(m2._c.hasattr("t"))
123        self.assertFalse(m2._c.hasattr("ts"))
124        self.assertFalse(m2._c.hasattr("tt"))
125        output_f = m2.forward(input)
126        self.assertEqual(output_s, output_f)
127
128    def test_freeze_module_with_submodule(self):
129        class SubModule(nn.Module):
130            def __init__(self) -> None:
131                super().__init__()
132                self.a = 11
133                self.b = 2
134
135            def forward(self, x):
136                return self.a + self.b
137
138        class SubModule2(nn.Module):
139            def __init__(self) -> None:
140                super().__init__()
141                self.a = 12
142                self.b = 2
143
144            def forward(self, x):
145                self.b = 30
146                return self.a + self.b
147
148        class TestModule(nn.Module):
149            def __init__(self) -> None:
150                super().__init__()
151                self.sub1 = SubModule()
152                self.sub2 = SubModule2()
153                self.a = 3
154                self.b = 4
155
156            def forward(self, x):
157                self.b = 20
158                return self.sub1(x) + self.a + self.b + self.sub2(x)
159
160        m = torch.jit.script(TestModule())
161        m.eval()
162        input = torch.randn(2, 2)
163        output_s = m.forward(input)
164        mf = torch.jit.freeze(m)
165
166        # Check if frozen module looks as below:
167        # module m {
168        #   attributes {
169        #     sub2 = ...
170        #      b =
171        #   }
172        #   ...
173        #   submodule {
174        #     module m {
175        #       attributes {
176        #         sub2 = ...
177        #         b =
178        #       }
179        #       ...
180        #     }
181        #   }
182        # }
183        mf = mf._c
184        self.assertFalse(mf.hasattr("sub1"))
185        self.assertFalse(mf.hasattr("a"))
186        self.assertTrue(mf.hasattr("b"))
187        self.assertTrue(mf.hasattr("sub2"))
188        self.assertTrue(mf.sub2.hasattr("b"))  # verify b is preserved in sub2
189        self.assertFalse(mf.sub2.hasattr("a"))  # verify a is removed in sub2
190        output_f = mf.forward(input)
191        self.assertEqual(output_s, output_f)
192
193    def test_freeze_module_with_fork(self):
194        class SubModule(nn.Module):
195            def __init__(self) -> None:
196                super().__init__()
197                self.a = torch.ones(20, 20)
198                self.b = torch.ones(20, 20)
199
200            def forward(self, x):
201                return self.a * self.b + x
202
203        class TestModule(nn.Module):
204            def __init__(self) -> None:
205                super().__init__()
206                self.sub = SubModule()
207
208            def forward(self, x):
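                # torch.jit._fork runs self.sub.forward asynchronously and
                # returns a future; torch.jit._wait joins it. Freezing must
                # account for attributes escaping into the forked subgraph.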
                fut = torch.jit._fork(self.sub.forward, x)
                y_hat = self.sub(x)
                y = torch.jit._wait(fut)
                return y_hat + y

        m = torch.jit.script(TestModule())
        m.eval()
        input = torch.randn(20, 20)
        output_s = m.forward(input)
        mf = torch._C._freeze_module(m._c)

        # Check if frozen module looks as below:
        # module m {
        #   attributes {
        #   }
        #   ...
        #   submodule {
        #   }
        # }
        self.assertFalse(mf.hasattr("a"))
        self.assertFalse(mf.hasattr("b"))
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_nested_fork(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(20, 20)
                self.b = torch.ones(20, 20)

            def forward(self, x):
                return self.a * self.b + x

        class SubModule2(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule()
                self.c = torch.ones(20, 20)

            def forward(self, x):
                fut = torch.jit._fork(self.sub.forward, x)
                y_hat = self.sub(x)
                y = torch.jit._wait(fut)
                return y_hat + y + self.c

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule2()
                self.d = 1

            def forward(self, x):
                fut = torch.jit._fork(self.sub.forward, x)
                y_hat = self.sub(x)
                y = torch.jit._wait(fut)
                self.d = 2
                return y_hat * y + self.d

        m = torch.jit.script(TestModule())
        m.eval()
        input = torch.randn(20, 20)
        output_s = m.forward(input)
        mf = torch._C._freeze_module(m._c)
        # Check if frozen module looks as below ('d' is mutated in forward, so
        # it is preserved):
        # module m {
        #   attributes {
        #     d = ...
        #   }
        #   ...
        #   submodule {
        #   }
        # }
281        self.assertFalse(mf.hasattr("a"))
282        self.assertFalse(mf.hasattr("b"))
283        self.assertFalse(mf.hasattr("c"))
284        self.assertTrue(mf.hasattr("d"))
285        output_f = mf.forward(input)
286        self.assertEqual(output_s, output_f)
287
288    def test_freeze_module_with_fork2(self):
289        @torch.jit.script
290        def foo(x):
291            return x * 2
292
293        class TestModule(nn.Module):
294            def __init__(self) -> None:
295                super().__init__()
296                self.a = torch.ones(20, 20)
297                self.b = torch.ones(20, 20)
298
299            def forward(self, x):
300                fut = torch.jit._fork(foo, self.a)
301                y_hat = foo(self.b)
302                y = torch.jit._wait(fut)
303                return y_hat + y
304
305        m = torch.jit.script(TestModule())
306        m.eval()
307        input = torch.randn(2, 2)
308        output_s = m.forward(input)
309        mf = torch._C._freeze_module(m._c)
310
        # Check if frozen module looks as below:
        # module m {
        #   attributes {
        #     self.a = ...
        #   }
        #   ...
        #   submodule {
        #   }
        # }
        # TODO: Although there are no mutations, the alias analysis
        # conservatively assumes there is a mutation because attributes are
        # passed to the fork subgraph; 'a' is preserved while 'b' is folded
        # (see the assertions below).
324        self.assertTrue(mf.hasattr("a"))
325        self.assertFalse(mf.hasattr("b"))
326        output_f = mf.forward(input)
327        self.assertEqual(output_s, output_f)
328
329    def test_freeze_module_with_fork_calling_module_method(self):
330        @torch.jit.script
331        def foo(x, y):
332            return x * y
333
334        class TestModule(nn.Module):
335            def __init__(self) -> None:
336                super().__init__()
337                self.a = torch.ones(20, 20)
338                self.b = torch.ones(20, 20)
339
340            @torch.jit.export
341            def foo(self, x):
342                return x * self.a
343
344            @torch.jit.export
345            def bar(self, x):
346                return x * self.b
347
348            def forward(self, x):
349                fut = torch.jit._fork(self.foo, self.b)
350                y_hat = self.bar(self.a)
351                y = torch.jit._wait(fut)
352                return y_hat + y
353
354        m = torch.jit.script(TestModule())
355        m.eval()
356        input = torch.randn(2, 2)
357        output_s = m.forward(input)
358        mf = torch._C._freeze_module(m._c)
359        # Check if frozen module looks as below:
360        # module m {
361        #   attributes {
362        #     self.b = ..
363        #   }
364        #   ...
        # TODO: Although there are no mutations, the alias analysis
        # conservatively assumes there is a mutation because attributes are
        # passed to the fork subgraph; 'b' is preserved.
368        self.assertFalse(mf.hasattr("a"))
369        self.assertTrue(mf.hasattr("b"))
370        output_f = mf.forward(input)
371        self.assertEqual(output_s, output_f)
372
373    def test_freeze_module_with_sharedclasstype(self):
374        class SubModule(nn.Module):
375            def __init__(self) -> None:
376                super().__init__()
377                self.a = torch.tensor([1.1])
378                self.b = torch.tensor([2.2])
379
380            def forward(self, x):
381                return self.a + self.b
382
383            @torch.jit.export
384            def modify_a(self, x):
385                self.a[0] += 10
386                return self.b
387
388            @torch.jit.export
389            def modify_b(self, x):
390                self.b[0] += 20
391                return self.a
392
393        class SubModule2(nn.Module):
394            def __init__(self) -> None:
395                super().__init__()
396                self.sub = SubModule()
397                self.b = torch.tensor([3.3])
398
399            def forward(self, x):
400                y = self.sub.modify_b(x)
401                return y + self.b
402
403        class TestModule(nn.Module):
404            def __init__(self) -> None:
405                super().__init__()
406                self.sub1 = SubModule()  # sub1 and sub2.sub shared same class type.
407                self.sub2 = SubModule2()
408                self.a = torch.tensor([4.4])
409
410            def forward(self, x):
411                z = self.sub1.modify_a(x)
412                return self.sub2(x) + z + self.a
413
414        m = torch.jit.script(TestModule())
415        m.eval()
416        input = torch.randn(2, 2)
417        output_s = m.forward(input)
418        mf = torch._C._freeze_module(m._c)
419
        # Check if frozen module looks as below:
        # module mf {
        #   attributes {
        #     sub1 = ...
        #     sub2 = ...
        #   }
        #   ...
        #   submodules {
        #     module sub1 {
        #       attributes {
        #         a = ...
        #         b = ...
        #       }
        #       ...
        #     }
        #     module sub2 {
        #       attributes {
        #         sub = ...
        #       }
        #       ...
        #       submodule {
        #         module sub {
        #           attributes {
        #             a = ...
        #             b = ...
        #           }
        #           ...
        #         }
        #       }
        #     }
        #   }
        # }

        self.assertTrue(mf.hasattr("sub1"))
        self.assertTrue(mf.sub1.hasattr("a"))
        self.assertTrue(mf.sub1.hasattr("b"))
        self.assertFalse(mf.hasattr("a"))
        self.assertTrue(mf.hasattr("sub2"))
        self.assertTrue(mf.sub2.hasattr("sub"))
        self.assertFalse(mf.sub2.hasattr("b"))
        self.assertTrue(mf.sub2.sub.hasattr("a"))
        self.assertTrue(mf.sub2.sub.hasattr("b"))
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_nestedaliasing(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = torch.tensor([2.2])

            def forward(self, x):
                return self.a + self.b

            @torch.jit.export
            def modify_a(self, x):
                self.a[0] = 10
                return self.b

            @torch.jit.export
            def modify_b(self, x):
                self.b[0] = 20
                return self.a

        Sub = SubModule()
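        # A single shared instance: TestModule.sub1 and SubModule2.sub below
        # alias the same module object.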

        class SubModule2(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = Sub  # aliasing

            def forward(self, x):
                return self.sub.a

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = Sub  # aliasing
                self.sub2 = SubModule2()

            def forward(self, x):
                z = self.sub1.modify_a(x)
                return self.sub2(x) + z

        m = torch.jit.script(TestModule())
        m.eval()
        mf = torch._C._freeze_module(m._c)
        self.assertTrue(mf.hasattr("sub1"))
        self.assertTrue(mf.sub1.hasattr("a"))
        self.assertFalse(mf.sub1.hasattr("b"))
        self.assertTrue(mf.hasattr("sub2"))
        self.assertTrue(mf.sub2.hasattr("sub"))
        self.assertTrue(
            mf.sub2.sub.hasattr("a")
        )  # Freezing detects that self.sub2.sub.a and self.sub1.a are aliases
516        self.assertFalse(mf.sub2.sub.hasattr("b"))
517        input = torch.randn(2, 2)
518        output_s = m.forward(input)
519        output_f = mf.forward(input)
520        self.assertEqual(output_s, output_f)

    # FIXME: JIT is not honoring aliasing. The 'Sub' module is copied; as a
    # result, the eager and scripted modules produce different outputs.
    def test_freeze_module_with_nestedaliasingscalar(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 1.1
                self.b = 2.2

            def forward(self, x):
                return self.a + self.b

            @torch.jit.export
            def modify_a(self, x):
                self.a = 10.0
                return self.b

            @torch.jit.export
            def modify_b(self, x):
                self.b = 20.0
                return self.a

        Sub = SubModule()

        class SubModule2(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = Sub  # aliasing

            def forward(self, x):
                return self.sub.a

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = Sub  # aliasing
                self.sub2 = SubModule2()

            def forward(self, x):
                z = self.sub1.modify_a(x)
                return self.sub2(x) + z

        m = TestModule()
        ms = torch.jit.script(m)
        ms.eval()
        mf = torch._C._freeze_module(ms._c)
        self.assertTrue(mf.hasattr("sub1"))
        self.assertTrue(mf.sub1.hasattr("a"))
        self.assertFalse(mf.sub1.hasattr("b"))
        # sub2 is fully folded because self.sub1 and self.sub2.sub are not aliases (scripting bug)
572        self.assertFalse(mf.hasattr("sub2"))
573        input = torch.randn(2, 2)
574        output = m.forward(input)
575        output_s = ms.forward(input)
576        output_f = mf.forward(input)
        # These should be equal, but scripting copied 'Sub' (see FIXME above),
        # so the eager output differs from the scripted one.
        self.assertNotEqual(output, output_s)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_preserve_sub_module(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = 2.2

            def forward(self, x):
                return self.a

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = SubModule()
                self.sub2 = SubModule()

            def forward(self, x):
                return self.sub2(x) + self.sub1(x)

        m = TestModule()
        ms = torch.jit.script(m)
        ms.eval()
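        # The second positional argument to torch._C._freeze_module lists the
        # attribute/method names to preserve (the internal counterpart of
        # preserved_attrs in torch.jit.freeze).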
603        mf = torch._C._freeze_module(ms._c, ["sub1"])
604
605        # Test that 'sub1' is preserved entirely and 'sub2' is completely folded
606        self.assertTrue(mf.hasattr("sub1"))
607        self.assertTrue(mf.sub1.hasattr("a"))
608        self.assertTrue(mf.sub1.hasattr("b"))
609        self.assertFalse(mf.hasattr("sub2"))
610        input = torch.randn(2, 2)
611        output_s = ms.forward(input)
612        output_f = mf.forward(input)
613        self.assertEqual(output_s, output_f)
614
615    def test_freeze_module_with_preserve_sub_module_and_mutation(self):
616        class SubModule(nn.Module):
617            def __init__(self) -> None:
618                super().__init__()
619                self.a = torch.tensor([1.1])
620                self.b = 2.2
621
622            def forward(self, x):
623                self.a[0] = 3.3
624                return self.a
625
626        class TestModule(nn.Module):
627            def __init__(self) -> None:
628                super().__init__()
                self.sub1 = SubModule()
                self.sub2 = SubModule()

            def forward(self, x):
                return self.sub2(x) + self.sub1(x)

        m = TestModule()
        ms = torch.jit.script(m)
        ms.eval()
        mf = torch._C._freeze_module(ms._c, ["sub1"])

        # Test that both sub1 and sub2 are preserved, and that 'b' is preserved
        # even though it is unused, to fulfill the user request to preserve 'sub1'.
642        self.assertTrue(mf.hasattr("sub1"))
643        self.assertTrue(mf.sub1.hasattr("a"))
644        self.assertTrue(mf.sub1.hasattr("b"))
645        self.assertTrue(mf.hasattr("sub2"))
646        self.assertTrue(mf.sub2.hasattr("a"))
647        self.assertTrue(mf.sub2.hasattr("b"))
648        input = torch.randn(2, 2)
649        output_s = ms.forward(input)
650        output_f = mf.forward(input)
651        self.assertEqual(output_s, output_f)
652
653    def test_freeze_module_with_helperfunction(self):
654        class SubModule(nn.Module):
655            def __init__(self) -> None:
656                super().__init__()
657                self.a = 11
658                self.b = 2
659
660            def forward(self, x):
661                return self.a + self.b
662
663        class TestModule(nn.Module):
664            def __init__(self) -> None:
665                super().__init__()
666                self.sub = SubModule()
667                self.a = 3
668                self.b = 4
669
670            def forward(self, x):
671                self.b = 20
672                return self._forward(x) + self.a + self.b
673
674            def _forward(self, x):
675                return self.sub(x)
676
677        m = torch.jit.script(TestModule())
678        m.eval()
679        input = torch.randn(2, 2)
680        mf = torch._C._freeze_module(m._c)
681        self.assertFalse(mf.hasattr("sub"))
682        self.assertFalse(mf.hasattr("a"))
683        self.assertTrue(mf.hasattr("b"))
684        with self.assertRaisesRegex(
685            AttributeError, "TestModule (.*) does not have a field with name '_forward'"
686        ):
687            mf._forward(x)  # noqa: F821
688
689    def test_freeze_module_with_inplace_mutable(self):
690        class FreezeMe(torch.jit.ScriptModule):
691            def __init__(self) -> None:
692                super().__init__()
693                self.a = [11, 22]
694
695            @torch.jit.script_method
696            def forward(self, x):
697                for i in range(3):
698                    self.a.append(i)
699                return self.a
700
701        m = FreezeMe()
702        m.eval()
703        m_f = torch._C._freeze_module(m._c)
704        self.assertTrue(m_f.hasattr("a"))
        m.forward(torch.tensor([3]))
        out = m_f.forward(torch.tensor([5]))
        expected = [11, 22, 0, 1, 2, 0, 1, 2]
        self.assertEqual(out, expected)

    # Mutable attributes
    def test_freeze_module_with_mutable_list(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = [1, 2]

            def forward(self, x):
                return self.a

        m = FreezeMe()
        m.eval()
        m.a.append(3)
        m_s = torch.jit.script(m)
        v = m_s.a
        v.append(4)
        m_s.a = v
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        # Post-freezing mutating m_s.a does not affect m_f (m_f has its own copy).
        v = m_s.a
        v.append(5)
        m_s.a = v
        self.assertFalse(m_f.hasattr("a"))
        out = m_f.forward(torch.tensor([5]))
        expected = [1, 2, 3, 4]
        self.assertEqual(out, expected)

    def test_freeze_module_with_mutable_dict(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = {"layer": "4"}

            def forward(self, x):
                return self.a

            @torch.jit.export
            def modify_a(self, x):
                self.a["layer"] = self.a["layer"] + "1"
                return self.a

        m = FreezeMe()
        m.eval()
        m.a["layer2"] = "3"
        m_s = torch.jit.script(m)
        t = torch.tensor(5)
        m_s.modify_a(t)
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        m.a["layer2"] += "2"
        m_s.modify_a(t)
        self.assertFalse(m_f.hasattr("a"))
        out = m_f.forward(t)
        expected = {"layer": "411", "layer2": "3"}
        self.assertEqual(out, expected)

    def test_freeze_module_with_mutable_tensor(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.0, 2.0, 3.0])

            def forward(self, x):
                return self.a

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.a[1] += 3.0
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        # Post-freezing tensor attribute mutations affect m_f.
        # FIXME: deep copy all folded attributes so that m_f has full ownership.
        m_s.a[0] += 5.0
        self.assertFalse(m_f.hasattr("a"))
        out = m_f.forward(torch.tensor([5]))
        expected = [6.0, 5.0, 3.0]
        self.assertEqual(out, expected)

    def test_freeze_module_with_tuple(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = (torch.tensor([1, 2, 3, 4, 5, 6]), "hi")

            def forward(self, x):
                if x[0] == 2.0:
                    self.a[0][0] = 10
                return self.a[0].sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        inp = torch.tensor([2.0])
        expected = m_s.forward(inp)
        m_s.a[0][0] = 1
        m_f = torch._C._freeze_module(m_s._c)
        self.assertFalse(m_f.hasattr("a"))
        out = m_f.forward(inp)
        self.assertEqual(out, expected)

    def test_freeze_module_with_tensor(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1, 2, 3, 4, 5, 6])

            def forward(self, x):
                x = self.a.view(2, 3)
                x[0][0] += 10
                return self.a.sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        inp = torch.tensor([5])
        expected = m_s.forward(inp)
        m_f = torch._C._freeze_module(m_s._c)
        self.assertTrue(m_f.hasattr("a"))
        m_f.a[0] -= 10
        out = m_f.forward(inp)
        self.assertEqual(out, expected)

    def test_freeze_module_with_list(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = [torch.tensor([1, 2, 3, 4, 5, 6])]

            def forward(self, x):
                self.a[0][1] += 10
                return self.a[0].sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        inp = torch.tensor([5])
        expected = m_s.forward(inp)
        m_s.a[0][1] -= 10
        m_f = torch._C._freeze_module(m_s._c)
        self.assertFalse(m_f.hasattr("a"))
        out = m_f.forward(inp)
        self.assertEqual(out, expected)

    def test_freeze_module_with_aliased_tensor_attr(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1, 2, 3, 4, 5, 6])
                self.b = self.a.view(2, 3)

            def forward(self, x):
                self.b[1] += 10
                return self.a.sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        self.assertTrue(m_f.hasattr("a"))
        inp = torch.tensor([5])
        out = m_f.forward(inp)
        expected = torch.tensor(51)  # 1 + 2 + 3 + 14 + 15 + 16
        self.assertEqual(out, expected)

    def test_freeze_module_with_aliased_tensor_attr2(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1, 2, 3, 4, 5, 6])
                self.b = {"layer": ([self.a.view(2, 3), torch.tensor([10])], 20)}
                self.c = ([self.a.view(2, 3), torch.tensor([10])], 20)
                self.d = (self.a.view(2, 3), 20)

            def forward(self, x):
                self.d[0][0] += 10
                return self.a.sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        inp = torch.tensor([5])
        expected = m_s.forward(inp)
        with self.assertRaisesRegex(
            RuntimeError, "module contains attributes values that overlaps"
        ):
            m_f = torch._C._freeze_module(m_s._c)

    def test_freeze_module_with_aliased_tensor_attr3(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1, 2, 3, 4, 5, 6])
                self.b = [self.a, torch.tensor([10])]

            def forward(self, x):
                self.a[1] += 10
                return self.b[0].sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        inp = torch.tensor([5])
        expected = m_s.forward(inp)
        m_f = torch._C._freeze_module(m_s._c)
        self.assertTrue(m_f.hasattr("a"))
        self.assertTrue(m_f.hasattr("b"))
        out = m_f.forward(inp)
        expected += 10  # account for the self.a[1] += 10 in forward
        self.assertEqual(out, expected)

    def test_freeze_module_with_aliased_tensor_attr4(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1, 2, 3, 4, 5, 6])
                self.b = [self.a, torch.tensor([10])]

            def forward(self, x):
                self.b[0][0] += 10
                return self.a.sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        inp = torch.tensor([5])
        expected = m_s.forward(inp)
        m_s.a[0] -= 10
        with self.assertRaisesRegex(
            RuntimeError, "module contains attributes values that overlaps"
        ):
            m_f = torch._C._freeze_module(m_s._c)

    def test_freeze_module_with_overlapping_attrs(self):
        a = torch.tensor([1, 2, 3, 4, 5, 6])

        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = [a.view(3, 2), torch.tensor([10])]
                self.c = (20, a.view(2, 3))

            def forward(self, x):
                self.b[0][0] += 10
                return self.c[1].sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        inp = torch.tensor([5])
        expected = m_s.forward(inp)
        a[0] -= 10
        with self.assertRaisesRegex(
            RuntimeError, "module contains attributes values that overlaps"
        ):
            m_f = torch._C._freeze_module(m_s._c)

    def test_freeze_module_with_aliased_attr(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = [1, 2, 3, 4, 5, 6]
                self.b = self.a
                self.c = (self.a, 10)

            def forward(self, x):
                self.b[1] += 10
                return str(self.a) + str(self.c)

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        # FIXME: It should be assertTrue. Currently scripting is making a copy for setting self.b (see #33034)
        self.assertFalse(m_f.hasattr("a"))
        self.assertFalse(m_f.hasattr("c"))
        inp = torch.tensor([5])
        out = m_f.forward(inp)
        expected = m_s.forward(inp)
        self.assertEqual(out, expected)

    # Check that attribute 'a' is preserved. Alias analysis detects that 'a' has
    # output writers. In this example, 'a' is not mutated, but we do not track
    # which sub-values of a composite IValue are mutated.
    def test_freeze_module_with_aliased_attr2(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = [1, 2, 3, 4, 5, 6]
                self.b = ([11], [10])

            def forward(self, x):
                v = self.a
                self.b = (v, [12])
                v2 = self.b[1]
                v2.append(7)
                return str(v) + str(v2)

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        self.assertTrue(m_f.hasattr("a"))
        inp = torch.tensor([5])
        out = m_f.forward(inp)
        expected = m.forward(inp)
        self.assertEqual(out, expected)

    def test_freeze_module_with_aliased_attr3(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = [1, 2, 3, 4, 5, 6]
                self.b = ([11], [10])

            def forward(self, x):
                v = self.a
                v2 = (v, [12])
                v3 = v2[0]
                v3.append(7)
                return str(self.a)

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        self.assertTrue(m_f.hasattr("a"))
        inp = torch.tensor([5])
        out = m_f.forward(inp)
        expected = m.forward(inp)
        self.assertEqual(out, expected)

    def test_freeze_module_return_self(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.0, 2.0, 3.0])

            def forward(self, x):
                return self

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        with self.assertRaisesRegex(
            RuntimeError, "attempted to freeze a module that return itself"
        ):
            m_f = torch._C._freeze_module(m_s._c)

    def test_freeze_module_inlining(self):
        @torch.jit.script  # noqa: B903
        class Obj:  # noqa: B903
            def __init__(self, x: int, y: int):
                self.x = x
                self.y = y

        class Mod(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.obj = Obj(2, 3)

            def forward(self, i: int):
                print(self.obj)
                return i

        mod = torch.jit.freeze(torch.jit.script(Mod().eval()))
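        # Freezing inlines the scripted Obj instance into the graph as a
        # prim::Constant; the checks below verify the constant is non-holding
        # (it does not own the object) and that this survives save/load.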
        obj = mod.graph.findNode("prim::Constant")
        self.assertTrue(torch._C._jit_object_is_non_holding(obj))

        buffer = io.BytesIO()
        torch.jit.save(mod, buffer)
        buffer.seek(0)

        loaded = torch.jit.load(buffer)
        obj = loaded.graph.findNode("prim::Constant")
        self.assertTrue(torch._C._jit_object_is_non_holding(obj))

    def test_freeze_module_return_sub_module(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, 1)

            def forward(self, x):
                return self.conv1

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        self.assertTrue(m_f.hasattr("conv1"))

    def test_freeze_module_no_forward(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = nn.Linear(10, 1)

            @torch.jit.export
            def foo(self, x):
                return self.lin(x)

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c, preservedAttrs=["foo"])
        input = torch.ones(10)
        self.assertEqual(m_s.foo(input), m_f.foo(input))

    def test_freeze_no_forward(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = nn.Linear(10, 1)

            @torch.jit.export
            def foo(self, x):
                return self.lin(x)

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        m_f = torch.jit.freeze(m_s, preserved_attrs=["foo"])
        input = torch.ones(10)
        self.assertEqual(m_s.foo(input), m_f.foo(input))

    def test_freeze_module_in_training_mode(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, 1)
                self.conv2 = nn.Conv2d(32, 64, 3, 1)
                self.dropout1 = nn.Dropout2d(0.25)
                self.dropout2 = nn.Dropout2d(0.5)
                self.fc1 = nn.Linear(9216, 128)
                self.fc2 = nn.Linear(128, 10)

            def forward(self, x):
                x = self.conv1(x)
                x = nn.functional.relu(x)
                x = self.conv2(x)
                x = nn.functional.max_pool2d(x, 2)
                x = self.dropout1(x)
                x = torch.flatten(x, 1)
                x = self.fc1(x)
                x = nn.functional.relu(x)
                x = self.dropout2(x)
                x = self.fc2(x)
                output = nn.functional.log_softmax(x, dim=1)
                return output

        model = torch.jit.script(Net())
        model.train()
        mTrain_freezed = torch._C._freeze_module(model._c)
        # verify mTrain_freezed looks exactly as:
        # module {
        #   attributes {
        #     conv1 = ...
        #     conv2 = ...
        #     dropout1 = ...
        #     dropout2 = ...
        #     fc1 = ...
        #     fc2 = ...
        #   }
        #   ...
        #   submodules {
        #     module conv1 {
        #       attributes {
        #          weight = ...
        #          bias = ...
        #       }
        #       ...
        #     }
        #     module conv2 {
        #       attributes {
        #          weight = ...
        #          bias = ...
        #       }
        #       ...
        #     }
        #     module dropout1 {
        #       attributes {
        #          training = ...
        #       }
        #       ...
        #     }
        #     module dropout2 {
        #       attributes {
        #          training = ...
        #       }
        #       ...
        #     }
        #     module fc1 {
        #       attributes {
        #          weight = ...
        #          bias = ...
        #       }
        #       ...
        #     }
        #     module fc2 {
        #       attributes {
        #          weight = ...
        #          bias = ...
        #       }
        #       ...
        #     }
        #   }
        # }
        self.assertFalse(mTrain_freezed.hasattr("training"))
        self.assertTrue(mTrain_freezed.hasattr("conv1"))
        self.assertFalse(mTrain_freezed.conv1.hasattr("training"))
        self.assertTrue(mTrain_freezed.conv1.hasattr("weight"))
        self.assertTrue(mTrain_freezed.conv1.hasattr("bias"))
        self.assertTrue(mTrain_freezed.hasattr("conv2"))
        self.assertFalse(mTrain_freezed.conv2.hasattr("training"))
        self.assertTrue(mTrain_freezed.conv2.hasattr("weight"))
        self.assertTrue(mTrain_freezed.conv2.hasattr("bias"))
        self.assertTrue(mTrain_freezed.hasattr("dropout1"))
        self.assertTrue(mTrain_freezed.dropout1.hasattr("training"))
        self.assertTrue(mTrain_freezed.hasattr("dropout2"))
        self.assertTrue(mTrain_freezed.dropout2.hasattr("training"))
        self.assertTrue(mTrain_freezed.hasattr("fc1"))
        self.assertTrue(mTrain_freezed.fc1.hasattr("weight"))
        self.assertTrue(mTrain_freezed.fc1.hasattr("bias"))
        self.assertTrue(mTrain_freezed.hasattr("fc2"))
        self.assertTrue(mTrain_freezed.fc2.hasattr("weight"))
        self.assertTrue(mTrain_freezed.fc2.hasattr("bias"))
        model.eval()
        mEval_freezed = torch._C._freeze_module(model._c)
        self.assertFalse(mEval_freezed.hasattr("conv1"))
        self.assertFalse(mEval_freezed.hasattr("conv2"))
        self.assertFalse(mEval_freezed.hasattr("dropout1"))
        self.assertFalse(mEval_freezed.hasattr("training"))
        self.assertFalse(mEval_freezed.hasattr("fc1"))
        self.assertFalse(mEval_freezed.hasattr("dropout2"))
        self.assertFalse(mEval_freezed.hasattr("fc2"))
        with self.assertRaisesRegex(
            AttributeError, "does not have a field with name 'state_dict'"
        ):
            print(mEval_freezed.state_dict())
        buffer = io.BytesIO()
        torch.jit.save(mEval_freezed, buffer)
        buffer.seek(0)
        m = torch.jit.load(buffer)
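        # After freezing, the forward graph should read no attributes at all,
        # so no prim::GetAttr nodes may remain after a save/load round trip.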
        FileCheck().check_not("GetAttr[name=").run(m._c._get_method("forward").graph)
        m2 = torch._C._freeze_module(model._c, preserveParameters=True)
        self.assertTrue(m2.hasattr("conv1"))
        self.assertTrue(m2.hasattr("conv2"))
        self.assertFalse(m2.hasattr("dropout1"))
        self.assertFalse(m2.hasattr("training"))
        self.assertTrue(m2.hasattr("fc1"))
        self.assertFalse(m2.hasattr("dropout2"))
        self.assertTrue(m2.hasattr("fc2"))

    def test_freeze_module_detach_gradient(self):
        mod = nn.Conv2d(8, 3, 4, 2, 1)
        self.assertTrue(mod.weight.requires_grad)
        smod = torch.jit.script(mod)
        smod.eval()
        fmod = torch._C._freeze_module(smod._c)
        self.assertTrue(mod.weight.requires_grad)
        self.assertTrue(smod.weight.requires_grad)
        self.assertFalse(fmod.hasattr("weight"))
        inp = torch.ones(1, 8, 32, 32)
        out1 = fmod.forward(inp)
        # FIXME: frozen module mutated from outside (original module).
        with torch.no_grad():
            smod.weight[0, 0, 0, 0] += 100.0
        out2 = fmod.forward(inp)
        out3 = smod(inp)
        self.assertNotEqual(out1, out2)
        self.assertEqual(out2, out3)

    def test_freeze_module_with_user_preserved_attr(self):
        class Module(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = torch.tensor([2.2])

            def forward(self, x):
                return self.a + self.b

        m = torch.jit.script(Module())
        m.eval()
        fm = torch._C._freeze_module(m._c, ["a"])
        # Attribute "a" is preserved
        self.assertTrue(fm.hasattr("a"))
        self.assertFalse(fm.hasattr("b"))

    def test_freeze_module_with_user_preserved_method(self):
        class Module(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = torch.tensor([2.2])

            def forward(self, x):
                return self.a + self.b

            @torch.jit.export
            def modify_a(self, x):
                self.a[0] += 10
                return self.b

            @torch.jit.export
            def modify_b(self, x):
                self.b[0] += 20
                return self.a

        m = torch.jit.script(Module())
        m.eval()
        fm = torch._C._freeze_module(m._c, ["modify_a"])
        # Both attribute "a" and method "modify_a" are preserved
        self.assertTrue(fm.hasattr("a"))
        self.assertFalse(fm.hasattr("b"))
        input = torch.randn(2, 2)
        expected = m.forward(input)
        out = fm.forward(input)
        self.assertEqual(out, expected)

    def test_freeze_module_with_user_preserved_method2(self):
        class Module(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = torch.tensor([2.2])

            def forward(self, x):
                self.b += 10
                return self.a + self.b

            @torch.jit.export
            def modify_a(self, x):
                self.a[0] += 10
                return self.b + self.a

        m = torch.jit.script(Module())
        m.eval()
        fm = torch._C._freeze_module(m._c, ["modify_a"])
        FileCheck().check('prim::GetAttr[name="a"]').run(fm.forward.graph)
        FileCheck().check('prim::GetAttr[name="b"]').run(fm.modify_a.graph)

    def test_freeze_module_with_user_preserved_attribute_on_submodule(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 1
                self.b = 2

            def forward(self):
                return self.a + self.b

        class Module(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = SubModule()
                self.sub2 = SubModule()

            def forward(self):
                return self.sub1() + self.sub2()

        m = torch.jit.script(Module())
        m.eval()
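        # preserved_attrs accepts dotted paths to reach attributes on submodules.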
        m = torch.jit.freeze(m, preserved_attrs=["sub1.a", "sub2.a"])
        fm = m._c

        self.assertTrue(fm.hasattr("sub1"))
        self.assertTrue(fm.sub1.hasattr("a"))
        self.assertFalse(fm.sub1.hasattr("b"))
        self.assertTrue(fm.hasattr("sub2"))
        self.assertTrue(fm.sub2.hasattr("a"))
        self.assertFalse(fm.sub2.hasattr("b"))
        self.assertEqual(m(), 6)
        m.sub1.a += 1
        self.assertEqual(m(), 7)

    def test_freeze_module_with_user_preserved_attribute_on_unused_submodule(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 1
                self.b = 2

            def forward(self):
                return self.a + self.b

            @torch.jit.export
            def method_a(self):
                return 42

        class Module(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule()

            def forward(self):
                return 1

        m = torch.jit.script(Module())
        m.eval()
        fm = torch.jit.freeze(m, preserved_attrs=["sub.a", "sub.method_a"])._c

        self.assertTrue(fm.hasattr("sub"))
        self.assertTrue(fm.sub.hasattr("a"))
        self.assertFalse(fm.sub.hasattr("b"))
        self.assertTrue(fm.sub._has_method("method_a"))

    def test_freeze_module_with_user_preserved_method_on_submodule(self):
        class SubModule(nn.Module):
            def forward(self, x):
                return self.method_a(x) + self.method_b(x)

            def method_a(self, x):
                return x * x

            def method_b(self, x):
                return x + x

        class Module(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule()

            def forward(self, x):
                return self.sub(x)

        m = torch.jit.script(Module())
        m.eval()
        fm = torch.jit.freeze(m, preserved_attrs=["sub.method_a"])._c

        self.assertTrue(fm.hasattr("sub"))
        self.assertTrue(fm.sub._has_method("method_a"))
        self.assertFalse(fm.sub._has_method("method_b"))

    @skipIfNoFBGEMM
    def test_module_with_shared_type_instances(self):
        class Child(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 1, 1).to(dtype=torch.float32)

            def forward(self, x):
                x = self.conv1(x)
                return x

        class Parent(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1).to(dtype=torch.float32)
                self.child = Child()
                self.child2 = Child()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.child(x)
                x = self.child2(x)
                x = self.dequant(x)
                return x

        def _static_quant(model):
            qModel = torch.ao.quantization.QuantWrapper(model)
            qModel.qconfig = torch.ao.quantization.default_qconfig
            torch.ao.quantization.prepare(qModel, inplace=True)
            qModel(torch.rand(4, 1, 4, 4, dtype=torch.float32))
            torch.ao.quantization.convert(qModel, inplace=True)
            return model

        with override_quantized_engine("fbgemm"):
            data = torch.randn(4, 1, 4, 4, dtype=torch.float32)
            m = Parent().to(torch.float32)
            m = _static_quant(m)
            m = torch.jit.script(m)
            m.eval()
            torch._C._jit_pass_inline(m.graph)
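            # torch._C._jit_pass_inline inlines function and method calls into
            # the graph; freezing then runs on the fully inlined module.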
            m_frozen = wrap_cpp_module(torch._C._freeze_module(m._c))
            # An earlier bug resulted in _packed_params being set to False.
            FileCheck().check_not("_packed_params = False").run(
                m_frozen._c.dump_to_str(True, True, False)
            )

            m_res = m(data)
            # Running the frozen module used to segfault.
            m_frozen_res = m_frozen(data)
            self.assertEqual(m_res, m_frozen_res)

    def test_module_getattr_indirection(self):
        @torch.jit.script
        class ValHolder:
            def __init__(self, val: int):
                self.val: int = val

        class Mod(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod1 = ValHolder(1)
                self.mod2 = ValHolder(2)

            def forward(self, cond: bool):
                if cond:
                    mod = self.mod1
                else:
                    mod = self.mod2
                return mod.val

        mod = Mod()
        mod.eval()
        frozen_mod = torch.jit.freeze(torch.jit.script(mod))
        mod_eager = Mod()
        self.assertEqual(mod_eager(True), frozen_mod(True))
        self.assertEqual(mod_eager(False), frozen_mod(False))

    def test_freeze_module_with_non_static_module_container_index(self):
        """
        Test that modules containing non-static ModuleDict or ModuleList
        indexing cannot be frozen.
        """

        @torch.jit.interface
        class ModuleInterface(torch.nn.Module):
            def forward(self, inp: Any) -> Any:
                pass
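
        # The interface type lets self.d[key] / self.l[idx] below be indexed
        # with a non-constant key or index, which compiles to
        # prim::ModuleContainerIndex; freezing rejects such modules, as the
        # asserted error message shows.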
1533
1534        class ImplementsInterface(torch.nn.Module):
1535            def forward(self, inp: Any) -> Any:
1536                if isinstance(inp, torch.Tensor):
1537                    return torch.max(inp, dim=0)
1538
1539                return inp
1540
1541        class ModWithDict(torch.nn.Module):
1542            def __init__(self) -> None:
1543                super().__init__()
1544                self.d = torch.nn.ModuleDict({"module": ImplementsInterface()})
1545
1546            def forward(self, x: torch.Tensor, key: str) -> Any:
1547                value: ModuleInterface = self.d[key]
1548                return value.forward(x)
1549
1550        m = torch.jit.script(ModWithDict())
1551        m.eval()
1552        with self.assertRaisesRegex(
1553            RuntimeError,
1554            "Freezing modules containing prim::ModuleContainerIndex is not supported",
1555        ):
1556            mf = torch._C._freeze_module(m._c)
1557
1558        class ModWithList(torch.nn.Module):
1559            def __init__(self) -> None:
1560                super().__init__()
1561                self.l = torch.nn.ModuleList([ImplementsInterface()])
1562
1563            def forward(self, x: torch.Tensor, idx: int) -> Any:
1564                value: ModuleInterface = self.l[idx]
1565                return value.forward(x)
1566
1567        m = torch.jit.script(ModWithList())
1568        m.eval()
1569        with self.assertRaisesRegex(
1570            RuntimeError,
1571            "Freezing modules containing prim::ModuleContainerIndex is not supported",
1572        ):
1573            mf = torch._C._freeze_module(m._c)
1574
1575    def test_freeze_with_interface_mutable(self):
1576        @torch.jit.interface
1577        class ModuleInterface(torch.nn.Module):
1578            def forward(self, inp: torch.Tensor) -> torch.Tensor:
1579                pass
1580
1581        class ImplementsInterface(torch.nn.Module):
1582            def __init__(self) -> None:
1583                super().__init__()
1584                self.sum = torch.zeros((2, 2))
1585
1586            def forward(self, inp: torch.Tensor) -> torch.Tensor:
1587                self.sum += inp.relu()
1588                return self.sum
1589
1590        class WrapperModule(torch.nn.Module):
1591            impl: ModuleInterface
1592
1593            def __init__(self) -> None:
1594                super().__init__()
1595                self.impl = ImplementsInterface()
1596
1597            def forward(self, x: torch.Tensor) -> torch.Tensor:
1598                return self.impl.forward(x)
1599
1600        m = torch.jit.script(WrapperModule())
1601        m.eval()
1602        m_frozen = torch.jit.freeze(m)
1603
1604        x = torch.rand((2, 2))
1605
1606        m_frozen(x)
1607        self.assertEqual(m_frozen.impl.sum, x.relu())
1608
1609    def test_freeze_with_swapping_interfaces(self):
1610        @torch.jit.interface
1611        class ModuleInterface(torch.nn.Module):
1612            def forward(self, inp: torch.Tensor) -> torch.Tensor:
1613                pass
1614
1615        class Implementation1(torch.nn.Module):
1616            def forward(self, inp: torch.Tensor) -> torch.Tensor:
1617                return inp.relu()
1618
1619        class Implementation2(torch.nn.Module):
1620            def forward(self, inp: torch.Tensor) -> torch.Tensor:
1621                return inp.sin()
1622
1623        class WrapperModule(torch.nn.Module):
1624            impl: ModuleInterface
1625
1626            def __init__(self) -> None:
1627                super().__init__()
1628                self.option1 = Implementation1()
1629                self.option2 = Implementation2()
1630                self.impl = self.option1
1631                self.idx = 0
1632
1633            def forward(self, x: torch.Tensor) -> torch.Tensor:
1634                self.idx += 1
1635                if self.idx % 2 == 1:
1636                    self.impl = self.option1
1637                else:
1638                    self.impl = self.option2
1639                return self.impl(x)
1640
1641        m = torch.jit.script(WrapperModule())
1642        m.eval()
1643        with self.assertRaisesRegex(
1644            RuntimeError, "Freezing does not support SetAttr on an interface type"
1645        ):
1646            m_frozen = torch.jit.freeze(m)
1647
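    # Note on the failure above: freezing folds module attributes into
    # constants, but an interface-typed attribute may be re-assigned to a
    # different concrete module at runtime, so there is no single constant to
    # fold. The same restriction drives the SetAttr failures in the
    # reassignment tests further below. (This summarizes the behavior these
    # tests exercise, not the implementation of the pass.)
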
    def test_freeze_recursive_interfaces(self):
        @torch.jit.interface
        class InnerInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                pass

        @torch.jit.interface
        class OuterInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                pass

        class InnerImpl(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.ones((2, 2))

            def forward(self, inp):
                return inp.cos() * self.x

        class OuterImpl(torch.nn.Module):
            inner_impl: InnerInterface

            def __init__(self) -> None:
                super().__init__()
                self.inner_impl = InnerImpl()

            def forward(self, inp):
                return inp.relu() + self.inner_impl(inp.sin())

        class WrapperModule(torch.nn.Module):
            outer_impl: OuterInterface

            def __init__(self) -> None:
                super().__init__()
                self.outer_impl = OuterImpl()

            def forward(self, inp):
                return self.outer_impl(inp) + inp

        m = WrapperModule()
        x = torch.rand((2, 2))
        expected = m(x)

        m_s = torch.jit.script(m)
        m_s.eval()
        m_s = torch.jit.freeze(m_s)
        actual = m_s(x)

        self.assertEqual(expected, actual)

    def test_freeze_recursive_interfaces_with_reassignment(self):
        @torch.jit.interface
        class InnerInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                pass

        @torch.jit.interface
        class OuterInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                pass

        class InnerImpl1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.ones((2, 2))

            def forward(self, inp):
                return inp.cos() * self.x

        class InnerImpl2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.ones((2, 2)) * 2

            def forward(self, inp):
                return inp.sin() / self.x

        class OuterImpl(torch.nn.Module):
            inner_impl: InnerInterface

            def __init__(self) -> None:
                super().__init__()
                self.inner_impl = InnerImpl1()
                self.impl1 = InnerImpl1()
                self.impl2 = InnerImpl2()
                self.idx = 0

            def forward(self, inp):
                self.idx += 1
                if self.idx % 2 == 0:
                    self.inner_impl = self.impl1
                else:
                    self.inner_impl = self.impl2
                return inp.relu() + self.inner_impl(inp.sin())

        class WrapperModule(torch.nn.Module):
            outer_impl: OuterInterface

            def __init__(self) -> None:
                super().__init__()
                self.outer_impl = OuterImpl()

            def forward(self, inp):
                return self.outer_impl(inp) + inp

        m = WrapperModule()

        m_s = torch.jit.script(m)
        m_s.eval()
        with self.assertRaisesRegex(
            RuntimeError, "Freezing does not support SetAttr on an interface type"
        ):
            m_s = torch.jit.freeze(m_s)

    def test_freeze_interface_swapping_two_methods(self):
        @torch.jit.interface
        class MyInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                pass

        class Impl1(torch.nn.Module):
            def forward(self, inp):
                return inp.cos()

        class Impl2(torch.nn.Module):
            def forward(self, inp):
                return inp.sin()

        class WrapperModule1(torch.nn.Module):
            interface_impl: MyInterface

            def __init__(self) -> None:
                super().__init__()
                self.interface_impl = Impl1()
                self.impl1 = Impl1()
                self.impl2 = Impl2()
                self.idx = 0

            def forward(self, x):
                return self.interface_impl(x)

            @torch.jit.export
            def other_method(self, x):
                self.idx += 1
                if self.idx % 2 == 0:
                    self.interface_impl = self.impl1
                else:
                    self.interface_impl = self.impl2
                return self.interface_impl(x)

        class WrapperModule2(torch.nn.Module):
            interface_impl: MyInterface

            def __init__(self) -> None:
                super().__init__()
                self.interface_impl = Impl1()
                self.impl1 = Impl1()
                self.impl2 = Impl2()
                self.idx = 0

            def forward(self, x):
                self.idx += 1
                if self.idx % 2 == 0:
                    self.interface_impl = self.impl1
                else:
                    self.interface_impl = self.impl2
                return self.interface_impl(x)

            @torch.jit.export
            def other_method(self, x):
                return self.interface_impl(x)

        m1 = torch.jit.script(WrapperModule1())
        m2 = torch.jit.script(WrapperModule2())

        m1.eval()
        m2.eval()

        with self.assertRaisesRegex(
            RuntimeError, "Freezing does not support SetAttr on an interface type"
        ):
            torch.jit.freeze(m1, preserved_attrs=["other_method"])

        with self.assertRaisesRegex(
            RuntimeError, "Freezing does not support SetAttr on an interface type"
        ):
            torch.jit.freeze(m2, preserved_attrs=["other_method"])

    def test_freeze_recursive_interfaces_same_name(self):
        @torch.jit.interface
        class InnerInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                pass

        @torch.jit.interface
        class OuterInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                pass

        class InnerImpl(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.ones((2, 2))

            def forward(self, inp):
                return inp.cos() * self.x

        class OuterImpl(torch.nn.Module):
            impl: InnerInterface

            def __init__(self) -> None:
                super().__init__()
                self.impl = InnerImpl()
                self.x = torch.ones((2, 2)) * 5

            def forward(self, inp):
                return self.other_method(inp)

            def other_method(self, inp):
                return inp.relu() + self.impl(inp.sin()) + self.x

        class WrapperModule(torch.nn.Module):
            impl: OuterInterface

            def __init__(self) -> None:
                super().__init__()
                self.impl = OuterImpl()

            def forward(self, inp):
                return self.impl(inp) + inp

        m = WrapperModule()
        x = torch.rand((2, 2))
        expected = m(x)

        m_s = torch.jit.script(m)
        m_s.eval()
        m_s = torch.jit.freeze(m_s)
        actual = m_s(x)

        self.assertEqual(expected, actual)

    def test_freeze_non_interface_module_swap(self):
        class InnerModule(torch.nn.Module):
            def __init__(self, x):
                super().__init__()
                self.x = x

            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                return inp.relu() + self.x

        class WrapperModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.option1 = InnerModule(torch.rand((2, 2)))
                self.option2 = InnerModule(torch.rand((2, 2)))
                self.impl = self.option1
                self.idx = 0

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.idx += 1
                if self.idx % 2 == 1:
                    self.impl = self.option1
                else:
                    self.impl = self.option2
                return self.impl(x)

        unfrozen = WrapperModule()
        m = torch.jit.script(unfrozen)
        m.eval()
        m_frozen = torch.jit.freeze(m)

        x = torch.rand((2, 2))
        expected = unfrozen(x)
        actual = m_frozen(x)
        self.assertEqual(expected, actual)

    @unittest.expectedFailure
    def test_freeze_interface_within_object(self):
        # I don't think there's any way to create a plain python object that
        # contains a torch.nn.Module inside it, but just in case... I'm not
        # sure freezing would handle this case correctly, so marking as xfail
        # so that if this ever _does_ start working someone will need to
        # investigate to make sure this is handled correctly.
        class MyIface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                pass

        class MyImpl(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                return inp.sin()

        class MyObject:
            impl: MyIface

            def run(self, x):
                return self.impl(x)

        class WrapperModule(torch.nn.Module):
            impl: MyObject

            def __init__(self) -> None:
                super().__init__()
                self.impl = MyObject()
                self.impl.impl = MyImpl()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.impl(x)

        unfrozen = WrapperModule()
        m = torch.jit.script(unfrozen)
        m.eval()
        m_frozen = torch.jit.freeze(m)

        x = torch.rand((2, 2))
        expected = unfrozen(x)
        actual = m_frozen(x)
        self.assertEqual(expected, actual)

    def test_freeze_non_module_class_getattr(self):
        class BoxCoder:
            def __init__(self, bbox_xform_clip):
                # type: (float) -> None
                self.bbox_xform_clip = bbox_xform_clip

            def decode(self, input):
                return input * self.bbox_xform_clip

        class MyModule(torch.nn.Module):
            __annotations__ = {
                "box_coder": BoxCoder,
            }

            def __init__(self) -> None:
                super().__init__()
                self.box_coder = BoxCoder(50.0)

            def forward(self, input):
                return self.box_coder.decode(input)

        model = MyModule()
        model.eval()
        script_model = torch.jit.freeze(torch.jit.script(model))
        inp = torch.randn([4, 4])
        output_eager = model(inp)
        self.assertEqual(output_eager, script_model(inp))
        FileCheck().check_not("GetAttr").run(script_model.graph)

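    # Note: the check above shows that freezing also folds attribute access
    # through plain TorchScript classes (here BoxCoder), not just through
    # nn.Module submodules, leaving no prim::GetAttr nodes behind.
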
    def test_freeze_module_with_tupleoutput_submodule(self):
        class SubModule(nn.Module):
            def forward(self, x):
                return (x + 1, x + 2)

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule()

            def forward(self, x):
                y1, y2 = self.sub(x)
                return y1 + y2

        m = torch.jit.script(TestModule())
        m = m.eval()
        mf = torch.jit.freeze(m)
        inp = torch.randn(2, 2)
        expected = m.forward(inp)
        output = mf.forward(inp)
        # Check that prim::TupleConstruct and prim::TupleUnpack
        # do not exist in the frozen graph
        FileCheck().check_not("prim::TupleConstruct").run(mf.graph)
        FileCheck().check_not("prim::TupleUnpack").run(mf.graph)
        self.assertEqual(output, expected)

    def test_freeze_module_with_call_method(self):
        class Mod(nn.Module):
            def __init__(self, val):
                super().__init__()
                self.param = nn.Parameter(val)

            def forward(self, x):
                # this method will change during freezing
                return x + self.param

            @torch.jit.export
            def make_prediction(self, x):
                y = x + x
                return self.forward(y)

        param = torch.rand([2, 2])
        x = torch.rand([2, 2])

        unscripted_mod = Mod(param)
        mod = torch.jit.script(unscripted_mod)
        mod.eval()
        mod = torch.jit.freeze(mod, preserved_attrs=["make_prediction"])

        self.assertEqual(
            mod.forward(x), unscripted_mod.forward(x), atol=1e-5, rtol=1e-5
        )


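# The optimization tests below assume an already-frozen module: they exercise
# graph-level passes (conv-bn folding, constant add/mul folding, MKL-DNN
# conversion, linear concatenation) that only fire once weights have become
# constants.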
@skipIfTorchDynamo("somehow causing hanging during python shutdown")
class TestFrozenOptimizations(JitTestCase):
    def setUp(self):
        super().setUp()
        self.default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.double)

    def tearDown(self):
        torch.set_default_dtype(self.default_dtype)
        super().tearDown()

    def test_conv_bn_folding(self):
        conv_bias = [True, False]
        module_pairs = [
            (nn.Conv1d, nn.BatchNorm1d),
            (nn.Conv2d, nn.BatchNorm2d),
            (nn.Conv3d, nn.BatchNorm3d),
        ]
        use_tracing = [True, False]
        bn_running_stats = [True, False]

        for use_bias, modules, tracing, track_stats in product(
            conv_bias, module_pairs, use_tracing, bn_running_stats
        ):

            class ConvBN(torch.nn.Module):
                def __init__(self, in_channels, out_channels, **kwargs):
                    super().__init__()
                    self.conv = modules[0](
                        in_channels, out_channels, bias=use_bias, **kwargs
                    )
                    self.bn = modules[1](
                        out_channels, eps=0.001, track_running_stats=track_stats
                    )

                def forward(self, x):
                    x = self.conv(x)
                    return self.bn(x)

            mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).eval()
            inps = [4, 3, 4]
            if modules[0] == nn.Conv2d:
                inps.append(inps[-1])
            if modules[0] == nn.Conv3d:
                inps.append(inps[-1])
                inps.append(inps[-1])

            inp = torch.rand(inps)

            if tracing:
                scripted_mod = torch.jit.trace(mod_eager, (inp,))
            else:
                scripted_mod = torch.jit.script(mod_eager)

            self.run_pass("inline", scripted_mod.graph)
            self.run_pass("peephole", scripted_mod.graph)
            self.run_pass("constant_propagation", scripted_mod.graph)

            FileCheck().check("conv").check("batch").run(scripted_mod.graph)
            # successfully no-ops with non-const inputs
            self.run_pass("fold_frozen_conv_bn", scripted_mod.graph)
            FileCheck().check("conv").check("aten::batch_norm").run(scripted_mod.graph)

            scripted_mod = torch.jit.freeze(scripted_mod)
            self.run_pass("fold_frozen_conv_bn", scripted_mod.graph)
            if track_stats:
                FileCheck().check("conv").check_not("aten::batch_norm").run(
                    scripted_mod.graph
                )
            else:
                FileCheck().check("conv").check("aten::batch_norm").run(
                    scripted_mod.graph
                )

            self.assertEqual(mod_eager(inp), scripted_mod(inp))
            self.assertEqual(mod_eager(inp), scripted_mod(inp))

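    # A minimal reference sketch (hypothetical helper, not used by the tests
    # above) of the arithmetic that conv-bn folding relies on: with frozen
    # batch-norm statistics, bn(conv(x)) equals a single conv whose weight is
    # scaled by gamma / sqrt(running_var + eps) and whose bias is shifted
    # accordingly.
    @staticmethod
    def _reference_fold_conv_bn(conv_w, conv_b, bn_rm, bn_rv, bn_w, bn_b, eps):
        scale = bn_w * torch.rsqrt(bn_rv + eps)
        # Reshape so the per-channel scale lines up with the out-channel dim.
        folded_w = conv_w * scale.reshape([-1] + [1] * (conv_w.dim() - 1))
        if conv_b is None:
            conv_b = torch.zeros_like(bn_rm)
        folded_b = (conv_b - bn_rm) * scale + bn_b
        return folded_w, folded_b
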
    def test_conv_bn_folding_not_forward(self):
        class ConvBN(torch.nn.Module):
            def __init__(self, in_channels, out_channels, **kwargs):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels, out_channels, bias=True, **kwargs
                )
                self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
                self.amt = 3.2

            def forward(self, x):
                x = self.conv(x)
                return self.bn(x)

            @torch.jit.export
            def make_prediction(self, x):
                return self.forward(x) + self.amt

        mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).eval()
        scripted_mod = torch.jit.script(mod_eager)
        torch._C._jit_pass_inline(scripted_mod.make_prediction.graph)
        FileCheck().check("conv").check("aten::batch_norm").run(
            scripted_mod.make_prediction.graph
        )

        # _jit_pass_optimize_frozen_graph should not be called on non-method attributes (e.g. "amt")
        scripted_mod = torch.jit.freeze(
            scripted_mod, preserved_attrs=["make_prediction", "amt"]
        )
        FileCheck().check("conv").check_not("aten::batch_norm").run(
            scripted_mod.make_prediction.graph
        )

    # During freezing this creates tensor constants that are attached to the
    # frozen graph, which is then kept alive by the compilation unit (causing
    # a leak).
    @skipCUDAMemoryLeakCheckIf(True)
    @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
    def test_conv_bn_folding_autocast_scenario_cuda(self):
        # CUDA conv takes input tensors which must all be the same dtype,
        # which can cause issues if folding produces inputs of different dtypes.

        class ConvBN(torch.nn.Module):
            def __init__(self, in_channels, out_channels, **kwargs):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels, out_channels, bias=False, dtype=torch.half, **kwargs
                )
                self.bn = torch.nn.BatchNorm2d(
                    out_channels, eps=0.001, dtype=torch.float
                )

            def forward(self, x):
                return self.bn(self.conv(x))

        mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).cuda().eval()
        scripted_mod = torch.jit.script(mod_eager)
        scripted_mod = torch.jit.freeze(scripted_mod)
        FileCheck().check("conv").check_not("aten::batch_norm").run(scripted_mod.graph)
        conv_node = scripted_mod.graph.findNode("aten::conv2d", True)
        self.assertTrue(conv_node is not None)
        bias_input = conv_node.namedInput("bias")
        self.assertTrue(bias_input is not None)
        self.assertTrue(bias_input.type().dtype() == torch.half)

        x = torch.rand((3, 3, 32, 32), dtype=torch.half).cuda()

        self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2)
        self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2)

    def test_conv_add_folding(self):
        @torch.no_grad()
        def test_conv_fusion(
            use_bias, module, tracing, op, scalar, add_tensor, expect_success
        ):
            class ConvOp(torch.nn.Module):
                __constants__ = ["use_scalar"]

                def __init__(self, in_channels, out_channels, tensor=None, **kwargs):
                    super().__init__()
                    self.conv = module(
                        in_channels, out_channels, bias=use_bias, **kwargs
                    )
                    self.conv2 = module(
                        in_channels, out_channels, bias=use_bias, **kwargs
                    )
                    self.use_scalar = scalar
                    tensor_size = [1 for _ in range(self.conv.weight.ndim)]
                    tensor_size[1] = self.conv.weight.size(0)
                    self.tensor = (
                        add_tensor
                        if add_tensor is not None
                        else torch.rand(tensor_size)
                    )
                    self.op = op

                def forward(self, x):
                    x = self.conv(x)
                    if self.use_scalar:
                        return self.op(x, 2.0)
                    else:
                        return self.op(x, self.tensor)

            mod_eager = ConvOp(3, 32, kernel_size=3, stride=2).eval()

            inps = [4, 3, 4]
            if module == nn.Conv2d:
                inps.append(inps[-1])
            if module == nn.Conv3d:
                inps.append(inps[-1])
                inps.append(inps[-1])

            inp = torch.rand(inps)

            if tracing:
                scripted_mod = torch.jit.trace(mod_eager, (inp,))
            else:
                scripted_mod = torch.jit.script(mod_eager)

            self.run_pass("inline", scripted_mod.graph)
            op_str = "aten::" + op.__name__

            FileCheck().check("conv").check(op_str).run(scripted_mod.graph)
            # successfully no-ops with non-const inputs
            self.run_pass("fold_frozen_conv_mul_or_div", scripted_mod.graph)
            self.run_pass("fold_frozen_conv_add_or_sub", scripted_mod.graph)
            FileCheck().check("conv").check(op_str).run(scripted_mod.graph)
            scripted_mod = torch.jit.freeze(scripted_mod)
            self.run_pass("fold_frozen_conv_mul_or_div", scripted_mod.graph)
            self.run_pass("fold_frozen_conv_add_or_sub", scripted_mod.graph)

            if expect_success:
                FileCheck().check("conv").check_not(op_str).run(scripted_mod.graph)
            else:
                FileCheck().check("conv").check(op_str).run(scripted_mod.graph)

            self.assertEqual(mod_eager(inp), scripted_mod(inp))
            self.assertEqual(mod_eager(inp), scripted_mod(inp))

        conv_bias = [True, False]
        modules = [nn.Conv1d, nn.Conv2d, nn.Conv3d]
        use_tracing = [False, True]
        use_scalar = [False, True]
        ops = [torch.add, torch.sub, torch.mul, torch.div]

        for use_bias, module, tracing, pytorch_op, scalar in product(
            conv_bias, modules, use_tracing, ops, use_scalar
        ):
            test_conv_fusion(
                use_bias,
                module,
                tracing,
                pytorch_op,
                scalar,
                add_tensor=None,
                expect_success=True,
            )

        for use_bias, pytorch_op in product(conv_bias, ops):
            # broadcasting add that changes the output shape, so folding must not fire
            test_conv_fusion(
                use_bias,
                nn.Conv2d,
                False,
                pytorch_op,
                False,
                add_tensor=torch.rand(32, 1, 32),
                expect_success=False,
            )

            # broadcasting add that preserves the output shape, so folding applies
            test_conv_fusion(
                use_bias,
                nn.Conv2d,
                False,
                pytorch_op,
                False,
                add_tensor=torch.rand(1, 1),
                expect_success=True,
            )

            # add with different dtype
            test_conv_fusion(
                use_bias,
                nn.Conv2d,
                False,
                pytorch_op,
                False,
                add_tensor=torch.tensor([2]).to(torch.int),
                expect_success=True,
            )

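    # A minimal sketch (hypothetical helper, not used above) of what the
    # fold_frozen_conv_add_or_sub pass exploits: a constant that broadcasts
    # like a per-channel bias can be absorbed into conv.bias, since
    # conv(x) + t == conv'(x) where conv' uses bias + t.flatten().
    @staticmethod
    def _reference_fold_add_into_bias(conv_bias, add_tensor):
        # Assumes add_tensor broadcasts as [1, C_out, 1, 1] or as a scalar.
        return conv_bias + add_tensor.flatten()
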
    def test_conv_mul_add_bn(self):
        class Conv_Mul_Add_Bn(nn.Module):
            def __init__(self, in_channels, out_channels, **kwargs):
                super().__init__()
                self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
                self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
                self.tensor1 = torch.tensor(2.2)
                self.tensor2 = torch.tensor(2)

            def forward(self, x):
                return self.bn(
                    torch.add(torch.mul(self.conv(x), self.tensor1), self.tensor2)
                )

        input = torch.randn(8, 3, 64, 64)
        model = Conv_Mul_Add_Bn(3, 32, kernel_size=3, stride=1).eval()

        with torch.no_grad():
            result = model(input)
            traced_model = torch.jit.trace(model, input).eval()
            traced_model = torch.jit.freeze(traced_model)
            tresult = traced_model(input)
            self.assertEqual(result, tresult)
            FileCheck().check("conv").check_not("aten::batch_norm").run(
                traced_model.graph
            )
            FileCheck().check("conv").check_not("aten::add").run(traced_model.graph)

    def test_linear_bn_folding(self):
        module_pairs = [
            (nn.Linear, nn.BatchNorm1d),
            (nn.Linear, nn.BatchNorm2d),
            (nn.Linear, nn.BatchNorm3d),
        ]
        use_tracing = [True, False]
        bn_running_stats = [True, False]

        for modules, tracing, track_stats in product(
            module_pairs, use_tracing, bn_running_stats
        ):

            class LinearBN(torch.nn.Module):
                def __init__(self, in_features, out_features):
                    super().__init__()
                    self.linear = modules[0](in_features, out_features)
                    self.bn = modules[1](
                        out_features, eps=0.001, track_running_stats=track_stats
                    )

                def forward(self, x):
                    x = self.linear(x)
                    return self.bn(x)

            mod_eager = LinearBN(32, 32).eval()

            inps = [3, 32]
            if modules[1] == nn.BatchNorm2d:
                inps.append(inps[-1])
                inps.append(inps[-1])
            if modules[1] == nn.BatchNorm3d:
                inps.append(inps[-1])
                inps.append(inps[-1])
                inps.append(inps[-1])

            inp = torch.rand(inps)

            if tracing:
                scripted_mod = torch.jit.trace(mod_eager, (inp,))
            else:
                scripted_mod = torch.jit.script(mod_eager)

            self.run_pass("inline", scripted_mod.graph)
            self.run_pass("peephole", scripted_mod.graph)
            self.run_pass("constant_propagation", scripted_mod.graph)

            FileCheck().check("linear").check("batch").run(scripted_mod.graph)
            # successfully no-ops with non-const inputs
            self.run_pass("fold_frozen_linear_bn", scripted_mod.graph)
            FileCheck().check("linear").check("aten::batch_norm").run(
                scripted_mod.graph
            )

            scripted_mod = torch.jit.freeze(scripted_mod)
            self.run_pass("fold_frozen_linear_bn", scripted_mod.graph)
            if track_stats:
                FileCheck().check("linear").check_not("aten::batch_norm").run(
                    scripted_mod.graph
                )
            else:
                FileCheck().check("linear").check("aten::batch_norm").run(
                    scripted_mod.graph
                )

            self.assertEqual(mod_eager(inp), scripted_mod(inp))
            self.assertEqual(mod_eager(inp), scripted_mod(inp))

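    # Linear-bn folding uses the same scale/shift arithmetic as the conv-bn
    # sketch earlier in this class; the test below checks that the pass bails
    # out when the batch-norm features do not line up with linear's output
    # features and folding would change semantics.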
    def test_bn_not_broadcast_with_linear(self):
        module_pairs = [
            (nn.Linear, nn.BatchNorm1d),
            (nn.Linear, nn.BatchNorm2d),
            (nn.Linear, nn.BatchNorm3d),
        ]
        use_tracing = [True, False]
        linear_in = 3
        # (linear_out, bn_in)
        # case 1: linear_out < bn_in
        # case 2: linear_out > bn_in
        # case 3: linear_out != bn_in && linear_out = 1
        dims = [(2, 4), (4, 2), (1, 2)]

        for modules, tracing, dim in product(module_pairs, use_tracing, dims):
            linear_out, bn_in = dim[0], dim[1]

            linear = modules[0](linear_in, linear_out)
            bn = modules[1](bn_in)
            mod_eager = nn.Sequential(linear, bn).eval()

            N, C = 3, bn_in
            input_shape = [N, C]
            if modules[1] == nn.BatchNorm1d:
                H = linear_in
                input_shape.append(H)
            elif modules[1] == nn.BatchNorm2d:
                H, W = 4, linear_in
                input_shape.append(H)
                input_shape.append(W)
            elif modules[1] == nn.BatchNorm3d:
                D, H, W = 4, 4, linear_in
                input_shape.append(D)
                input_shape.append(H)
                input_shape.append(W)

            inp = torch.rand(input_shape)

            if tracing:
                scripted_mod = torch.jit.trace(mod_eager, (inp,))
            else:
                scripted_mod = torch.jit.script(mod_eager)

            self.run_pass("inline", scripted_mod.graph)
            self.run_pass("peephole", scripted_mod.graph)
            self.run_pass("constant_propagation", scripted_mod.graph)

            FileCheck().check("linear").check("batch").run(scripted_mod.graph)
            self.run_pass("fold_frozen_linear_bn", scripted_mod.graph)
            FileCheck().check("linear").check("aten::batch_norm").run(
                scripted_mod.graph
            )

            frozen_mod = torch.jit.freeze(scripted_mod)
            self.run_pass("fold_frozen_linear_bn", frozen_mod.graph)
            # successfully skipped folding
            FileCheck().check("linear").check("aten::batch_norm").run(frozen_mod.graph)

            self.assertEqual(mod_eager(inp), frozen_mod(inp))
            self.assertEqual(mod_eager(inp), frozen_mod(inp))

            # fusing the pair directly should also fail, with an assertion
            with self.assertRaisesRegex(
                AssertionError,
                "To fuse, linear.out_features == bn.num_features or bn.num_features == 1",
            ):
                nn.utils.fusion.fuse_linear_bn_eval(linear, bn)

    @skipCUDAMemoryLeakCheckIf(True)
    @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
    def test_linear_bn_folding_autocast_scenario_cuda(self):
        module_pairs = [
            (nn.Linear, nn.BatchNorm1d),
            (nn.Linear, nn.BatchNorm2d),
            (nn.Linear, nn.BatchNorm3d),
        ]
        use_tracing = [True, False]
        bn_running_stats = [True, False]

        for modules, tracing, track_stats in product(
            module_pairs, use_tracing, bn_running_stats
        ):

            class LinearBN(torch.nn.Module):
                def __init__(self, in_features, out_features):
                    super().__init__()
                    self.linear = modules[0](
                        in_features, out_features, bias=False, dtype=torch.half
                    )
                    self.bn = modules[1](out_features, eps=0.001, dtype=torch.float)

                def forward(self, x):
                    x = self.linear(x)
                    return self.bn(x)

            mod_eager = LinearBN(32, 32).cuda().eval()

            inps = [3, 32]
            if modules[1] == nn.BatchNorm2d:
                inps.append(inps[-1])
                inps.append(inps[-1])
            if modules[1] == nn.BatchNorm3d:
                inps.append(inps[-1])
                inps.append(inps[-1])
                inps.append(inps[-1])

            x = torch.rand(inps, dtype=torch.half).cuda()

            if tracing:
                scripted_mod = torch.jit.trace(mod_eager, (x,))
            else:
                scripted_mod = torch.jit.script(mod_eager)
            scripted_mod = torch.jit.freeze(scripted_mod)
            FileCheck().check("linear").check_not("aten::batch_norm").run(
                scripted_mod.graph
            )
            lin_node = scripted_mod.graph.findNode("aten::linear", True)
            self.assertTrue(lin_node is not None)
            weight_input = lin_node.namedInput("weight")
            bias_input = lin_node.namedInput("bias")
            self.assertTrue(bias_input is not None)
            self.assertTrue(weight_input.type().dtype() == torch.half)
            self.assertTrue(bias_input.type().dtype() == torch.half)

            self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2)
            self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2)

    @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
    def test_linear_concat(self):
        out_dims = [[5, 10], [1, 5]]

        for w1_dim, w2_dim in out_dims:

            class ModMultLinear(nn.Module):
                def __init__(self, w1_dim, w2_dim):
                    super().__init__()
                    self.w1 = nn.Parameter(torch.rand([w1_dim, 5]))
                    self.b1 = nn.Parameter(torch.rand([w1_dim]))
                    self.w2 = nn.Parameter(torch.rand([w2_dim, 5]))
                    self.b2 = nn.Parameter(torch.rand([w2_dim]))

                def forward(self, in_tensor1):
                    res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1)
                    res2 = torch._C._nn.linear(in_tensor1, self.w2, self.b2)
                    return res1, res2

            mod_eager = ModMultLinear(w1_dim, w2_dim).eval()

            test_val1 = torch.rand([50, 5])
            self.check_linear_optimizations(mod_eager, 2, 1, (test_val1,))

    @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
    def test_linear_concat_complex(self):
        """
        Test that interleaving multiple optimizations does not cause errors
        and that the graph is optimized as expected.
        """

        class ModMultLinear(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                w1_dim = 5
                w2_dim = 10
                self.w1 = nn.Parameter(torch.rand([w1_dim, 5]))
                self.b1 = nn.Parameter(torch.rand([w1_dim]))
                self.w2 = nn.Parameter(torch.rand([w2_dim, 5]))
                self.b2 = nn.Parameter(torch.rand([w2_dim]))

            def forward(self, in_tensor1):
                res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1)
                res3 = torch._C._nn.linear(res1, self.w2, self.b2)
                res2 = torch._C._nn.linear(in_tensor1, self.w2, self.b2)
                res4 = torch._C._nn.linear(res1, self.w1, self.b1)
                return res2, res3, res4

        mod_eager = ModMultLinear().eval()
        test_val1 = torch.rand([50, 5])
        self.check_linear_optimizations(mod_eager, 4, 2, (test_val1,))

    @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
    def test_linear_concat_different_input(self):
        """
        The optimization pass should leave the graph unchanged because the
        two linear ops consume different input tensors.
        """

        # Freezing requires that the graph be a module
        class ModMultLinear(nn.Module):
            def __init__(self, w1_dim, w2_dim):
                super().__init__()
                self.w1 = nn.Parameter(torch.rand([w1_dim, 5]))
                self.b1 = nn.Parameter(torch.rand([w1_dim]))
                self.w2 = nn.Parameter(torch.rand([w2_dim, 5]))
                self.b2 = nn.Parameter(torch.rand([w2_dim]))

            def forward(self, in_tensor1, in_tensor2):
                res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1)
                res2 = torch._C._nn.linear(in_tensor2, self.w2, self.b2)
                return res1, res2

        mod_eager = ModMultLinear(5, 5).eval()
        test_val1 = torch.rand([50, 5])
        test_val2 = torch.rand([50, 5])
        self.check_linear_optimizations(mod_eager, 2, 2, (test_val1, test_val2))

    @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
    def test_linear_multiple_blocks(self):
        class ModMultLinear(nn.Module):
            def __init__(self, w1_dim, w2_dim):
                super().__init__()
                self.w1 = nn.Parameter(torch.rand([w1_dim, 5]))
                self.b1 = nn.Parameter(torch.rand([w1_dim]))
                self.w2 = nn.Parameter(torch.rand([w2_dim, 5]))
                self.b2 = nn.Parameter(torch.rand([w2_dim]))

            def forward(self, in_tensor1, in_tensor2, cond: bool):
                res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1)
                if cond:
                    res3 = torch._C._nn.linear(in_tensor2, self.w2, self.b2)
                    res4 = torch._C._nn.linear(in_tensor1, self.w2, self.b1)
                else:
                    raise AssertionError
                res2 = torch._C._nn.linear(in_tensor1, self.w2, self.b1)
                return res1, res2, res3, res4

        mod_eager = ModMultLinear(5, 5).eval()
        test_val1 = torch.rand([50, 5])
        test_val2 = torch.rand([50, 5])
        self.check_linear_optimizations(mod_eager, 4, 3, (test_val1, test_val2, True))

    def check_linear_optimizations(
        self, eager_mod, orig_linears, new_linears, test_vals
    ):
        for is_cuda in [False, True]:
            if is_cuda:
                mod_to_device = eager_mod.cuda()
                test_vals_to_device = [
                    t.cuda() if isinstance(t, torch.Tensor) else t for t in test_vals
                ]
            else:
                mod_to_device = eager_mod
                test_vals_to_device = test_vals

            script_mod = torch.jit.script(mod_to_device)
            op_graph = script_mod.graph

            FileCheck().check_count("aten::linear", orig_linears, exactly=True).run(
                op_graph
            )
            # successfully no-ops with non-const inputs
            self.run_pass("concat_frozen_linear", op_graph)
            FileCheck().check_count("aten::linear", orig_linears, exactly=True).run(
                op_graph
            )

            script_mod = torch.jit.freeze(script_mod)
            op_graph = script_mod.graph
            self.run_pass("concat_frozen_linear", op_graph)
            if is_cuda:
                FileCheck().check_count("aten::linear", new_linears, exactly=True).run(
                    op_graph
                )
            else:
                FileCheck().check_count("aten::linear", orig_linears, exactly=True).run(
                    op_graph
                )

            self.assertEqual(
                mod_to_device(*test_vals_to_device), script_mod(*test_vals_to_device)
            )

    def test_optimize_freeze_module(self):
        in_channels, out_channels = 3, 32
        conv = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=2, bias=True
        )
        bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
        mod = torch.nn.Sequential(conv, bn)
        # Set optimize_numerics=False here; by default freezing runs run_frozen_optimizations
        frozen_mod = torch.jit.freeze(
            torch.jit.script(mod.eval()), optimize_numerics=False
        )
        # inspect frozen mod
        FileCheck().check("batch_norm").run(frozen_mod.graph)
        torch.jit.run_frozen_optimizations(frozen_mod)
        FileCheck().check_not("batch_norm").run(frozen_mod.graph)

        # by default, freezing runs run_frozen_optimizations itself
        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()))
        FileCheck().check_not("batch_norm").run(frozen_mod.graph)

    def test_freeze_remove_dropout(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.dropout = nn.Dropout(0.5)

            def forward(self, x):
                return self.dropout(x)

        mod = torch.jit.script(Net())
        # inspect mod
        torch._C._jit_pass_inline(mod.graph)
        FileCheck().check("aten::dropout").run(mod.graph)
        frozen_mod = torch.jit.freeze(mod.eval())
        FileCheck().check_not("aten::dropout").run(frozen_mod.graph)

        input = torch.randn(2)
        output_s = mod.forward(input)
        output_f = frozen_mod.forward(input)
        self.assertEqual(output_s, output_f)

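    # In eval mode dropout is the identity (F.dropout with training=False
    # returns its input unchanged), which is what lets freezing drop the op
    # entirely in the tests above and below.
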
    def test_freeze_remove_feature_dropout(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.dropout = nn.Dropout2d(0.5)

            def forward(self, x):
                return self.dropout(x)

        mod = torch.jit.script(Net().eval())
        # inspect mod
        torch._C._jit_pass_inline(mod.graph)
        FileCheck().check("aten::feature_dropout").run(mod.graph)
        frozen_mod = torch.jit.freeze(mod)
        FileCheck().check_not("aten::feature_dropout").run(frozen_mod.graph)

        input = torch.randn(2, 2, 1, 1)
        output_s = mod.forward(input)
        output_f = frozen_mod.forward(input)
        self.assertEqual(output_s, output_f)

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_freeze_mkldnn(self):
        conv = torch.nn.Conv2d(3, 32, kernel_size=3, stride=2).eval().float()
        convmkl = mkldnn_utils.to_mkldnn(conv)
        out = torch.jit.freeze(torch.jit.script(convmkl.eval()))
        inp = torch.rand([4, 3, 4, 4]).float()
        self.assertEqual(out(inp.to_mkldnn()).to_dense(), conv(inp))

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_conv_to_mkldnn(self):
        with set_default_dtype(torch.float):
            for module, trace in product([nn.Conv2d, nn.Conv3d], [False, True]):
                mod = module(3, 32, kernel_size=3, stride=2).eval()
                inps = [4, 3, 4]
                if module == nn.Conv2d:
                    inps.append(inps[-1])
                if module == nn.Conv3d:
                    inps.append(inps[-1])
                    inps.append(inps[-1])

                inp = torch.rand(inps)
                if trace:
                    scripted_mod = torch.jit.trace(mod, (inp,))
                else:
                    scripted_mod = torch.jit.script(mod)

                self.run_pass("inline", scripted_mod.graph)

                FileCheck().check("conv").run(scripted_mod.graph)
                # successfully no-ops with non-const inputs
                self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
                FileCheck().check_not("to_mkldnn").run(scripted_mod.graph)

                scripted_mod = torch.jit.freeze(scripted_mod)
                self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
                FileCheck().check("to_mkldnn").check("prim::mkldnn_convolution").check(
                    "to_dense"
                ).run(scripted_mod.graph)

                self.assertEqual(mod(inp), scripted_mod(inp))
                self.assertEqual(mod(inp), scripted_mod(inp))

    def test_linear_transpose(self):
        class ModLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = torch.nn.Parameter(torch.rand(30))
                self.weight = torch.nn.Parameter(torch.rand([30, 20]))

            def forward(self, x):
                return torch._C._nn.linear(x, self.weight, self.bias)

        mod_eager = ModLinear().eval()
        test_val = torch.rand([50, 20])
        self.check_linear_optimizations_2(
            mod_eager, 1, 0, "transpose_frozen_linear", (test_val,)
        )

    def test_linear_non_constant_weight(self):
        class ModLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = torch.nn.Parameter(torch.rand(30))

            def forward(self, x, weight):
                return torch._C._nn.linear(x, weight, self.bias)

        mod_eager = ModLinear().eval()
        test_val = torch.rand([50, 20])
        test_weight = torch.rand([30, 20])
        self.check_linear_optimizations_2(
            mod_eager, 1, 1, "transpose_frozen_linear", (test_val, test_weight)
        )

    def check_linear_optimizations_2(
        self, eager_mod, orig_linears, new_linears, opt_pass, test_vals
    ):
        # TODO: merge with check_linear_optimizations once both diffs land
        mod_to_device = eager_mod
        test_vals_to_device = test_vals

        script_mod = torch.jit.script(mod_to_device)
        op_graph = script_mod.graph

        FileCheck().check_count("aten::linear", orig_linears, exactly=True).run(
            op_graph
        )
        # successfully no-ops with non-const inputs
        self.run_pass(opt_pass, op_graph)
        FileCheck().check_count("aten::linear", orig_linears, exactly=True).run(
            op_graph
        )

        script_mod = torch.jit.freeze(script_mod)
        op_graph = script_mod.graph
        self.run_pass(opt_pass, op_graph)
        FileCheck().check_count("aten::linear", new_linears, exactly=True).run(op_graph)

        self.assertEqual(
            mod_to_device(*test_vals_to_device), script_mod(*test_vals_to_device)
        )

    @staticmethod
    def conv():
        # Generic composable conv for testing purposes
        return nn.Conv2d(8, 8, 1)

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_collapse_adjacent_conversions(self):
        with set_default_dtype(torch.float):
            mod = nn.Sequential(self.conv(), self.conv()).eval()
            scripted_mod = torch.jit.script(mod)
            scripted_mod = torch.jit.freeze(scripted_mod)
            self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
            FileCheck().check("to_mkldnn").check("prim::mkldnn_convolution").check(
                "prim::mkldnn_convolution"
            ).check("to_dense").run(scripted_mod.graph)
            FileCheck().check_count("to_mkldnn", 1, exactly=True).run(
                scripted_mod.graph
            )

            inp = torch.rand([1, 8, 8, 8])
            self.assertEqual(scripted_mod(inp), mod(inp))
            self.assertEqual(scripted_mod(inp), mod(inp))

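    # The single to_mkldnn count above shows the pass keeps intermediates in
    # MKL-DNN layout across adjacent convs, so one to_mkldnn/to_dense pair
    # brackets the whole chain rather than one pair per op.
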
    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_mkldnn_fuser_broadcasting(self):
        class Add(nn.Module):
            def __init__(self, tensor):
                super().__init__()
                self.tensor = tensor

            def forward(self, x):
                return x + self.tensor

        with set_default_dtype(torch.float):
            for add_inp in [8], [8, 8, 1]:
                mod = nn.Sequential(self.conv(), Add(torch.rand(add_inp))).eval()
                scripted_mod = torch.jit.script(mod)
                scripted_mod = torch.jit.freeze(scripted_mod)
                self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
                FileCheck().check("prim::BroadcastMKLDNNTensors").run(
                    scripted_mod.graph
                )
                inp = torch.rand([1, 8, 8, 8])
                self.assertEqual(scripted_mod(inp), mod(inp))
                self.assertEqual(scripted_mod(inp), mod(inp))

                # for good measure, check that broadcasting does not work without this op
                # so we can remove the op if it ever gets supported
                with self.assertRaisesRegex(RuntimeError, ""):
                    (
                        torch.rand([1, 8, 8, 8]).to_mkldnn()
                        + torch.rand(add_inp).to_mkldnn()
                    )

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_mkldnn_inplace_removal(self):
        class AddMul(nn.Module):
            def __init__(self, tensor):
                super().__init__()
                self.tensor = tensor

            def forward(self, x):
                return x.add_(self.tensor).div_(self.tensor) - 4

        with set_default_dtype(torch.float):
            mod = nn.Sequential(self.conv(), AddMul(torch.rand([8]))).eval()
            scripted_mod = torch.jit.script(mod)
            scripted_mod = torch.jit.freeze(scripted_mod)
            self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
            # add_ and div_ get un-inplaced for the conversion and then re-inplaced
2929            FileCheck().check("aten::to_mkldnn").check("aten::add_").check(
2930                "aten::div_"
2931            ).run(scripted_mod.graph)
2932            inp = torch.rand([1, 8, 8, 8])
2933            self.assertEqual(scripted_mod(inp), mod(inp))
2934            self.assertEqual(scripted_mod(inp), mod(inp))
2935
2936    @unittest.skipIf(
2937        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
2938    )
2939    @skipIfNoTorchVision
2940    def test_maxpool_mkldnn(self):
        with set_default_dtype(torch.float):
            model = torchvision.models.resnet18()
            sub_model = torch.nn.Sequential(
                model.conv1, model.bn1, model.relu, model.maxpool
            )
            mod = torch.jit.freeze(torch.jit.script(sub_model.eval()))
            N, C, H, W = 10, 3, 224, 224
            inp = torch.randn(N, C, H, W)
            self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
            FileCheck().check("max_pool").check("to_dense").run(mod.graph)
            FileCheck().check_count("to_dense", 1, exactly=True).run(mod.graph)
            self.assertEqual(mod(inp), sub_model(inp))

    @unittest.skipIf(torch.backends.mkldnn.is_available(), "Testing no mkldnn")
    def test_conv_to_mkldnn_no_mkldnn(self):
        # test that freezing and the conversion pass run without error
        # when MKL-DNN is not available
        with set_default_dtype(torch.float):
            mod = torch.jit.script(nn.Conv2d(3, 32, kernel_size=3, stride=2).eval())
            frozen = torch.jit.freeze(mod)
            self.run_pass("convert_frozen_ops_to_mkldnn", frozen.graph)
            inp = torch.rand([4, 3, 4, 4])
            self.assertEqual(frozen(inp), mod(inp))

    @unittest.skipIf(not (TEST_CUDNN or TEST_WITH_ROCM), "requires CUDNN")
    def test_freeze_conv_relu_fusion(self):
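        # optimize_for_inference should fuse conv + relu (plus the optional
        # residual add) into a single cudnn/miopen fused convolution op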
        with set_default_dtype(torch.float):
            conv_bias = [True, False]
            conv_ops = [nn.Conv2d, nn.Conv3d]
            use_add_z = [True, False]
            use_tracing = [True, False]
            for use_bias, conv, add_z, tracing in product(
                conv_bias, conv_ops, use_add_z, use_tracing
            ):

                class Net(nn.Module):
                    def __init__(self, in_channels, out_channels, **kwargs):
                        super().__init__()
                        self.conv = conv(
                            in_channels, out_channels, bias=use_bias, **kwargs
                        )
                        self.relu = nn.ReLU(inplace=True)
                        self.add_z = add_z

                    def forward(self, x):
                        z = self.conv(x)
                        out = self.conv(x)
                        if self.add_z:
                            out += z
                        out = self.relu(out)
                        return out

                mod_eager = Net(3, 6, kernel_size=3, stride=2).eval().cuda()

                inps = [5, 3, 4, 4]
                if conv == nn.Conv3d:
                    inps.append(inps[-1])
                inp = torch.rand(inps).cuda()

                if tracing:
                    scripted_mod = torch.jit.trace(mod_eager, (inp,))
                else:
                    scripted_mod = torch.jit.script(mod_eager)

                frozen_mod = torch.jit.optimize_for_inference(scripted_mod)
                if TEST_WITH_ROCM:
                    if add_z:
                        FileCheck().check("aten::miopen_convolution_add_relu").run(
                            frozen_mod.graph
                        )
                    else:
                        FileCheck().check("aten::miopen_convolution_relu").run(
                            frozen_mod.graph
                        )
                else:
                    if add_z:
                        FileCheck().check("aten::cudnn_convolution_add_relu").run(
                            frozen_mod.graph
                        )
                    else:
                        FileCheck().check("aten::cudnn_convolution_relu").run(
                            frozen_mod.graph
                        )

                self.assertEqual(mod_eager(inp), frozen_mod(inp))

    @unittest.skipIf(not (TEST_CUDNN or TEST_WITH_ROCM), "requires CUDNN")
    def test_freeze_conv_relu_fusion_not_forward(self):
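        # fusion should also apply to a preserved non-forward method when it
        # is passed to optimize_for_inference via other_methods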
        with set_default_dtype(torch.float):

            class Net(nn.Module):
                def __init__(self, in_channels, out_channels, **kwargs):
                    super().__init__()
                    self.conv = nn.Conv2d(
                        in_channels, out_channels, bias=False, **kwargs
                    )
                    self.relu = nn.ReLU(inplace=True)

                def forward(self, x):
                    out = self.conv(x)
                    out = self.relu(out)
                    return out

                @torch.jit.export
                def make_prediction(self, x):
                    return self.forward(x)

            mod_eager = Net(3, 6, kernel_size=3, stride=2).eval().cuda()

            inps = [5, 3, 4, 4]
            inp = torch.rand(inps).cuda()

            scripted_mod = torch.jit.script(mod_eager)

            frozen_mod = torch.jit.freeze(
                scripted_mod, preserved_attrs=["make_prediction"]
            )
            optimized_mod = torch.jit.optimize_for_inference(
                frozen_mod, other_methods=["make_prediction"]
            )
            if TEST_WITH_ROCM:
                FileCheck().check("aten::miopen_convolution_relu").run(
                    optimized_mod.make_prediction.graph
                )
            else:
                FileCheck().check("aten::cudnn_convolution_relu").run(
                    optimized_mod.make_prediction.graph
                )

            self.assertEqual(
                mod_eager.make_prediction(inp), optimized_mod.make_prediction(inp)
            )

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_numel_less_than_size_with_padding(self):
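        # edge case: the input has fewer elements than the padded conv size
        # implies; the optimized module must still match eager mode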
        with set_default_dtype(torch.float):

            class MyModule(nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.conv1 = nn.Conv2d(
                        1,
                        2,
                        kernel_size=(2, 4),
                        stride=2,
                        padding=2,
                        dilation=(2, 1),
                    )

                def forward(self, i0):
                    x = self.conv1(i0)
                    o0 = torch.max(x, i0)
                    o1 = torch.clip(x, -1.5, 1.5)
                    return o0, o1

            i0 = torch.zeros((1, 1, 1, 2), dtype=torch.float32)
            mod = MyModule()
            out = mod(i0)

            exported = torch.jit.trace(mod, [i0])
            exported = torch.jit.optimize_for_inference(exported)

            eout = exported(i0)
            self.assertTrue(all(torch.allclose(x, y) for x, y in zip(out, eout)))

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_incompatible_perf_formats(self):
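        # conv and max_pool may prefer different performant MKLDNN formats;
        # the conversion pass must still produce results matching eager mode
        # when their outputs are added together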
        with set_default_dtype(torch.float):

            class Mod(nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.conv = torch.nn.Conv2d(3, 64, 3, 2)
                    self.max_pool = torch.nn.MaxPool2d(111, 111)

                def forward(self, x):
                    a = self.conv(x)
                    b = self.max_pool(a)
                    return a + b

            model = Mod()
            model.eval()
            mod = torch.jit.freeze(torch.jit.script(model))
            N, C, H, W = 10, 3, 224, 224
            inp = torch.randn(N, C, H, W)
            self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
            self.assertEqual(model(inp), mod(inp))

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_pool2d_batchnorm(self):
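        # a conv/relu prefix followed by various pooling (or batchnorm) layers
        # and hardswish should run entirely in the MKLDNN domain, with a
        # single to_dense immediately before the return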
        with set_default_dtype(torch.float):
            pooling_layers = [
                torch.nn.AdaptiveAvgPool2d(4),
                # torch.nn.AdaptiveMaxPool2d(4), # returns tuples
                torch.nn.MaxPool2d(4),
                torch.nn.AvgPool2d(4),
                torch.nn.BatchNorm2d(64).eval(),
            ]

            for pl in pooling_layers:
                sub_model = torch.nn.Sequential(
                    torch.nn.Conv2d(3, 64, 2, 2),
                    torch.nn.ReLU(),
                    pl,
                    torch.nn.Hardswish(),
                )
                sub_model.eval()
                mod = torch.jit.freeze(torch.jit.script(sub_model))
                N, C, H, W = 10, 3, 224, 224
                inp = torch.randn(N, C, H, W)
                # these two passes are needed to remove
                # a size check in BatchNorm2d
                removeExceptions(mod.graph)
                self.run_pass("dce", mod.graph)
                self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
                FileCheck().check("aten::to_dense").check_next("return").run(mod.graph)
                self.assertEqual(sub_model(inp), mod(inp))

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_pool3d_batchnorm(self):
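        # 3D variant of test_pool2d_batchnorm: the whole chain should stay in
        # the MKLDNN domain, with a single to_dense before the return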
        with set_default_dtype(torch.float):
            pooling_layers = [
                torch.nn.MaxPool3d(4),
                # torch.nn.AdaptiveAvgPool3d(4), # no ideep bindings
                # torch.nn.AdaptiveMaxPool3d(4), # returns tuples
                torch.nn.AvgPool3d(4),
                torch.nn.BatchNorm3d(64).eval(),
            ]

            for pl in pooling_layers:
                sub_model = torch.nn.Sequential(
                    torch.nn.Conv3d(3, 64, 2, 2),
                    torch.nn.ReLU(),
                    pl,
                    torch.nn.Hardswish(),
                )
                sub_model.eval()
                mod = torch.jit.freeze(torch.jit.script(sub_model))
                N, C, H, W, D = 10, 3, 64, 64, 64
                inp = torch.randn(N, C, D, H, W)
                # these two passes are needed to remove
                # a size check in BatchNorm3d
                removeExceptions(mod.graph)
                self.run_pass("dce", mod.graph)
                self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
                FileCheck().check("aten::to_dense").check_next("return").run(mod.graph)
                self.assertEqual(sub_model(inp), mod(inp))

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    @skipIfNoTorchVision
    def test_conv_hardswish(self):
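        # conv followed by various pointwise activations (and equivalent
        # clamps) should be converted so that only one to_dense remains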
        with set_default_dtype(torch.float):

            class Clamp(torch.nn.Module):
                def __init__(self, min_val, max_val, **kwargs):
                    super().__init__()
                    self.min_val = min_val
                    self.max_val = max_val

                def forward(self, x):
                    return torch.clamp(x, self.min_val, self.max_val)

            N, C, H, W = 10, 3, 224, 224
            activations = [
                torch.nn.Hardswish(),
                torch.nn.Hardsigmoid(),
                torch.nn.ReLU6(),
                torch.nn.Tanh(),
                torch.nn.Hardtanh(0.0, 6.0),
                torch.nn.Hardtanh(1.0, 100.0),
                torch.nn.Hardtanh(-100.0, -1.0),
                torch.nn.GELU(),
                Clamp(-100.0, -1.0),
                Clamp(1.0, 100.0),
                Clamp(0.0, 6.0),
                Clamp(-1.0, 0.0),
            ]

            model = torchvision.models.resnet18()
            for activation in activations:
                sub_model = torch.nn.Sequential(model.conv1, activation)
                sub_model.eval()
                mod = torch.jit.freeze(torch.jit.script(sub_model))
                inp = torch.randn(N, C, H, W)
                self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
                FileCheck().check_count("aten::to_dense", 1, exactly=True).run(
                    mod.graph
                )
                self.assertEqual(sub_model(inp), mod(inp))

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_hardswish_hardsigmoid(self):
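        # exercise the MKLDNN hardswish/hardsigmoid ops directly through
        # hand-written IR and compare them against the reference aten ops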
        with set_default_dtype(torch.float):
            op_map = {
                "prim::MKLDNNHardSwish": F.hardswish,
                "prim::MKLDNNHardSigmoid": F.hardsigmoid,
            }

            input_sizes = ([0], [1], [3], [1, 3, 8, 8])
            for mkldnn_opname, aten_op in op_map.items():
                for size in input_sizes:
                    for inplace in (True, False):
                        inplace_str = "_" if inplace else ""
                        inplace_tgt = "%34" if inplace else "%35"
                        graph_str = f"""graph(%input.1 : Tensor):
                            %33 : None = prim::Constant()
                            %34 : Tensor = aten::to_mkldnn(%input.1, %33)
                            %35 : Tensor = {mkldnn_opname}{inplace_str}(%34)
                            return ({inplace_tgt})
                        """
                        g = torch._C.parse_ir(graph_str)
                        m = self.createFunctionFromGraph(g)
                        x = torch.rand(size)
                        # `inplace=False` is intentional here; otherwise we would
                        # modify the input, and we aren't testing the aten
                        # in-place implementations anyway
                        self.assertEqual(aten_op(x, inplace=False), m(x).to_dense())

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_scalar_mul(self):
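        # scalar multiplications should be converted to MKLDNN ScalarMul ops,
        # using the in-place variant where liveness allows it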
        with set_default_dtype(torch.float):

            class Mod(nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.mod = nn.Conv2d(8, 8, 1, padding=1)

                def forward(self, x):
                    a1 = self.mod(x) * 4
                    return a1 * 4 + a1 * 5.0

            mod = Mod().eval()
            scripted = torch.jit.freeze(torch.jit.script(mod))
            optimized = torch.jit.optimize_for_inference(scripted)
            inp = torch.rand([1, 8, 8, 8])
            # a1 can't be inplaced at its first use, but it can at its second
            FileCheck().check("ScalarMul(").check("ScalarMul_").run(optimized.graph)
            self.assertEqual(optimized(inp), mod(inp))

    def test_remove_detach(self):
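        # detach is a no-op for inference, so freezing should remove it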
        class Mod(nn.Module):
            def forward(self, x):
                y = x.detach()
                return y * y

        mod = Mod().eval()
        frozen_mod = torch.jit.freeze(torch.jit.script(mod))
        inp = torch.randn((2, 2))
        FileCheck().check_not("aten::detach").run(frozen_mod.graph)
        self.assertEqual(frozen_mod(inp), mod(inp))

    def test_remove_detach_not_applied(self):
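        # here removing detach would change the result of `x is y`, so
        # freezing must leave it in place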
        class Mod(nn.Module):
            def forward(self, x):
                y = x.detach()
                return x is y

        mod = Mod().eval()
        frozen_mod = torch.jit.freeze(torch.jit.script(mod))
        inp = torch.randn((2, 2))
        FileCheck().check("aten::detach").run(frozen_mod.graph)
        self.assertEqual(frozen_mod(inp), mod(inp))


@skipIfTorchDynamo("somehow causes hangs during Python shutdown")
@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled")
class TestMKLDNNReinplacing(JitTestCase):
    def setUp(self):
        super().setUp()
        self.default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.float)

    def tearDown(self):
        super().tearDown()
        torch.set_default_dtype(self.default_dtype)

    def getConv(self):
        return nn.Conv2d(3, 32, kernel_size=3, stride=2).eval()

    def getInput(self):
        return torch.rand([4, 3, 4, 4])

    def freezeAndConvert(self, mod):
        mod = torch.jit.freeze(torch.jit.script(mod.eval()))
        self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
        return mod

    def checkResults(self, mod1, mod2):
        inp = self.getInput()
        self.assertEqual(mod1(inp), mod2(inp))

    def test_successful(self):
        # simple conv + hardswish + relu chain: the hardswish and relu should
        # run in place on the MKLDNN tensor produced by the convolution
        mod_eager = nn.Sequential(self.getConv(), nn.Hardswish(), nn.ReLU())
        mod = self.freezeAndConvert(mod_eager)
        FileCheck().check("mkldnn_convolution").check_next(
            "prim::MKLDNNHardSwish_"
        ).check_next("aten::relu_").run(mod.graph)
        self.checkResults(mod_eager, mod)

    def test_merge_liveness(self):
        class Mod(nn.Module):
            def __init__(self, tensor):
                super().__init__()
                self.tensor = tensor

            def forward(self, x):
                # this mul can be inplaced since x is dead after this use
                temporary = x * self.tensor
                # temporary's lifespan extends to the return node,
                # so the add cannot be inplaced
                return temporary + temporary, temporary

        mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1])))
        mod = self.freezeAndConvert(mod_eager)
        FileCheck().check("aten::mul_").check_not("aten::add_").run(mod.graph)
        self.checkResults(mod_eager, mod)

    def test_always_alive_values(self):
        class Mod(nn.Module):
            def __init__(self, tensor):
                super().__init__()
                self.tensor = tensor

            def forward(self, x):
                # x can't be inplaced because it's a return value; also check
                # that the inplacing pass doesn't try to inplace self.tensor,
                # because it is always alive
                return x * self.tensor, x

        mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1])))
        mod = self.freezeAndConvert(mod_eager)
        FileCheck().check_not("aten::mul_").run(mod.graph)
        self.checkResults(mod_eager, mod)

        conv = self.getConv()

        class Mod(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.tensor = torch.rand([4, 32, 1, 1])
                self.conv = conv

            def forward(self, x):
                # the shapes don't add up here; we are just testing
                # a particular pattern
                conv_output = self.conv(x)
                return conv_output, self.conv(torch.add(x, x))

        mod = self.freezeAndConvert(Mod())
        # x is an input to the graph, and so it should not be inplaced
        # in the torch.add(x, x) call
        FileCheck().check_not("aten::add_").run(mod.graph)

    def test_switch_inputs_to_inplace(self):
        class Mod(nn.Module):
            def __init__(self, tensor):
                super().__init__()
                self.tensor = tensor

            def forward(self, x):
                # self.tensor cannot be inplaced, but x can; since add is
                # commutative, the inputs to add_ can be reversed
                return self.tensor + x

        mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1])))
        mod = self.freezeAndConvert(mod_eager)
        FileCheck().check("aten::add_").run(mod.graph)
        self.checkResults(mod_eager, mod)
3463