# Owner(s): ["module: inductor"]
import functools
import importlib
import itertools
import os
import sys

import torch
from torch import nn
from torch._inductor import config as inductor_config
from torch.testing._internal.common_cuda import TEST_CUDNN


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

from inductor.test_inductor_freezing import TestCase
from inductor.test_torchinductor import check_model, check_model_gpu, copy_tests
from torch.testing._internal.common_utils import TEST_WITH_ASAN
from torch.testing._internal.inductor_utils import skipCUDAIf


importlib.import_module("functorch")
importlib.import_module("filelock")

from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU


aten = torch.ops.aten


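# BinaryFoldingTemplate checks Inductor's binary folding: an elementwise
# add/sub/mul/div between a conv output and a foldable constant operand (or the
# affine transform applied by a batch norm in eval mode) should be folded into
# the conv's weight and bias at compile time.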
class BinaryFoldingTemplate(TestCase):
    @skipCUDAIf(TEST_CUDNN, "CUDNN has accuracy issues for this test")
    def test_conv_binary_folding(self):
        @torch.no_grad()
        def test_conv_fusion(use_bias, module, op, scalar, add_tensor, expect_success):
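            # ConvOp applies a convolution followed by the binary op under test,
            # using either the scalar 2.0 or a tensor operand.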
            class ConvOp(nn.Module):
                __constants__ = ["use_scalar"]

                def __init__(self, in_channels, out_channels, device, **kwargs):
                    super().__init__()
                    self.conv = module(
                        in_channels, out_channels, bias=use_bias, **kwargs
                    ).to(device)
                    self.conv2 = module(
                        in_channels, out_channels, bias=use_bias, **kwargs
                    ).to(device)
                    self.use_scalar = scalar
                    tensor_size = [1 for _ in range(self.conv.weight.ndim)]
                    tensor_size[1] = self.conv.weight.size(0)
                    self.tensor = torch.nn.Parameter(
                        add_tensor
                        if add_tensor is not None
                        else torch.rand(tensor_size).to(device)
                    )
                    self.op = op

                def forward(self, x):
                    x = self.conv(x)
                    if self.use_scalar:
                        return self.op(x, 2.0)
                    else:
                        return self.op(x, self.tensor)

            from torch._inductor.compile_fx import compile_fx, compile_fx_inner

            aten_binary = {
                torch.add: aten.add.Tensor,
                torch.sub: aten.sub.Tensor,
                torch.mul: aten.mul.Tensor,
                torch.div: aten.div.Tensor,
            }
            n_binary_ops = 0

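            # Wrap the inner compile so we can count how many binary ops of the
            # tested kind remain in the FX graph Inductor receives; successful
            # folding should leave none.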
            def my_inner_compile(gm, example_inputs, *args, **kwargs):
                out = compile_fx_inner(gm, example_inputs, *args, **kwargs)
                nonlocal n_binary_ops
                binary_ops = [n for n in gm.graph.nodes if n.target == aten_binary[op]]
                n_binary_ops += len(binary_ops)
                return out

            torch._dynamo.reset()
            mod_eager = ConvOp(3, 32, self.device, kernel_size=3, stride=2).eval()
            out_optimized = torch.compile(
                mod_eager,
                backend=functools.partial(compile_fx, inner_compile=my_inner_compile),
            )

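            # Build the input: [N, C, L] for Conv1d, appending spatial dims for
            # Conv2d/Conv3d.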
            inps = [4, 3, 4]
            if module == nn.Conv2d:
                inps.append(inps[-1])
            if module == nn.Conv3d:
                inps.append(inps[-1])
                inps.append(inps[-1])

            torch.manual_seed(1234)
            inp = torch.rand(inps).to(self.device)
            out_eager = mod_eager(inp)
            out_optimized = out_optimized(inp)
            self.assertEqual(out_optimized, out_eager)
            if expect_success:
                self.assertTrue(n_binary_ops == 0)
            else:
                self.assertTrue(n_binary_ops == 1)

        conv_bias = [True, False]
        modules = [nn.Conv1d, nn.Conv2d, nn.Conv3d]
        use_scalar = [True, False]
        ops = [torch.add, torch.sub, torch.mul, torch.div]
        for use_bias, module, pytorch_op, scalar in itertools.product(
            conv_bias, modules, ops, use_scalar
        ):
            # TODO: support scalar case
            expect_success = not scalar
            test_conv_fusion(
                use_bias,
                module,
                pytorch_op,
                scalar,
                add_tensor=None,
                expect_success=expect_success,
            )

        for use_bias, pytorch_op in itertools.product(conv_bias, ops):
            # broadcasting add over a spatial dimension; folding is not expected
            test_conv_fusion(
                use_bias,
                nn.Conv2d,
                pytorch_op,
                False,
                add_tensor=torch.rand(32, 1, 32).to(self.device),
                expect_success=False,
            )

            # broadcasting add of a single-element tensor; folding is expected
            test_conv_fusion(
                use_bias,
                nn.Conv2d,
                pytorch_op,
                False,
                add_tensor=torch.rand(1, 1).to(self.device),
                expect_success=True,
            )

            # add with different dtype
            test_conv_fusion(
                use_bias,
                nn.Conv2d,
                pytorch_op,
                False,
                add_tensor=torch.tensor([2]).to(torch.float64).to(self.device),
                expect_success=False,
            )

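    # In eval mode, batch norm is an affine transform of the conv output, so with
    # freezing enabled it should be folded into the convolution, leaving no
    # standalone elementwise binary ops in the compiled graph.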
    @inductor_config.patch({"freezing": True})
    def test_conv_bn_folding(self):
        @torch.no_grad()
        def test_conv_fusion(use_bias, module, expect_success):
            class ConvOp(nn.Module):
                def __init__(self, in_channels, out_channels, device, **kwargs):
                    super().__init__()
                    self.conv = module[0](
                        in_channels, out_channels, bias=use_bias, **kwargs
                    ).to(device)
                    self.bn = module[1](out_channels).to(device)

                def forward(self, x):
                    x = self.conv(x)
                    return self.bn(x)

            from torch._inductor.compile_fx import compile_fx, compile_fx_inner

            aten_binary = [
                aten.add.Tensor,
                aten.sub.Tensor,
                aten.mul.Tensor,
                aten.div.Tensor,
            ]
            n_binary_ops = 0

            def my_inner_compile(gm, example_inputs, *args, **kwargs):
                out = compile_fx_inner(gm, example_inputs, *args, **kwargs)
                nonlocal n_binary_ops
                binary_ops = [n for n in gm.graph.nodes if n.target in aten_binary]
                n_binary_ops += len(binary_ops)
                return out

            torch._dynamo.reset()
            mod_eager = ConvOp(3, 32, self.device, kernel_size=3, stride=2).eval()
            out_optimized = torch.compile(
                mod_eager,
                backend=functools.partial(compile_fx, inner_compile=my_inner_compile),
            )

            inps = [4, 3, 4]
            if module[0] == nn.Conv2d:
                inps.append(inps[-1])
            if module[0] == nn.Conv3d:
                inps.append(inps[-1])
                inps.append(inps[-1])

            inp = torch.rand(inps).to(self.device)
            out_eager = mod_eager(inp)
            out_optimized = out_optimized(inp)
            self.assertEqual(out_optimized, out_eager, atol=2e-04, rtol=1e-5)
            if expect_success:
                self.assertTrue(n_binary_ops == 0)
            else:
                self.assertTrue(n_binary_ops > 1)

        conv_bias = [True, False]
        modules = [
            (nn.Conv1d, nn.BatchNorm1d),
            (nn.Conv2d, nn.BatchNorm2d),
            (nn.Conv3d, nn.BatchNorm3d),
        ]
        for use_bias, module in itertools.product(conv_bias, modules):
            test_conv_fusion(
                use_bias,
                module,
                expect_success=True,
            )


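# copy_tests stamps device-specific copies of the template's tests onto the
# classes below (CPU, and GPU when available and not under ASAN).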
if HAS_CPU and not torch.backends.mps.is_available():

    class FreezingCpuTests(TestCase):
        common = check_model
        device = "cpu"
        autocast = torch.cpu.amp.autocast

    copy_tests(BinaryFoldingTemplate, FreezingCpuTests, "cpu")

if HAS_GPU and not TEST_WITH_ASAN:

    class FreezingGpuTests(TestCase):
        common = check_model_gpu
        device = GPU_TYPE
        autocast = torch.amp.autocast(device_type=GPU_TYPE)

    copy_tests(BinaryFoldingTemplate, FreezingGpuTests, GPU_TYPE)


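# The template itself is deleted so the test runner does not collect it directly.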
del BinaryFoldingTemplate

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CPU or HAS_GPU:
        run_tests(needs="filelock")