# Owner(s): ["module: inductor"]
import copy
import os
import random

import torch
from torch import nn
from torch._dynamo.utils import same
from torch._inductor import config
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_cuda import tf32_off
from torch.testing._internal.inductor_utils import HAS_CUDA


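# Controls how verify_accuracy() introduces graph breaks: with the default ("1") the
# model is wrapped in DistributedDataParallel; otherwise a manual
# torch._dynamo.graph_break() is inserted instead.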
USE_DDP_WRAPPER = os.environ.get("USE_DDP_WRAPPER", "1") == "1"


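# Two stacked convolutions with an optional manual graph break in between; used as
# the default model for the accuracy tests below.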
class Model2Conv(nn.Module):
    def __init__(self, dim=512, manual_graph_break=False):
        super().__init__()
        self.conv1 = nn.Conv2d(3, dim, kernel_size=3, stride=2, bias=False)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=2, bias=False)
        self.manual_graph_break = manual_graph_break

    def forward(self, x):
        x = self.conv1(x)
        if self.manual_graph_break:
            torch._dynamo.graph_break()
        x = self.conv2(x)
        return x

    def get_example_inputs(self):
        return (torch.rand(2, 3, 16, 16),)


class TestLayoutOptim(TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        import torch.distributed as dist

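        # verify_accuracy() may wrap models with DistributedDataParallel, which needs
        # an initialized process group; a single-rank NCCL group is enough for that.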
        # Don't use a fixed port, so stress-test runs don't collide on the same port.
        tot_retry = 5
        for retry_no in range(tot_retry):
            try:
                port = random.randint(10000, 60000)
                dist.init_process_group(
                    backend="nccl",
                    init_method=f"tcp://localhost:{port}",
                    world_size=1,
                    rank=0,
                )
                break
            except RuntimeError:
                if retry_no == tot_retry - 1:
                    raise
                else:
                    continue

    def verify_accuracy(
        self, model_class, use_ddp_wrapper=USE_DDP_WRAPPER, is_train=False
    ):
        # There are two potential ways to introduce graph breaks:
        #   1. manually, via torch._dynamo.graph_break()
        #   2. by wrapping the model with DDP
        # If we are not using DDP to introduce graph breaks, do so manually.
        def wrap_mod(m):
            if is_train:

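                # For training, run forward + backward and return the gradients so the
                # caller can compare gradients rather than the forward output.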
                def f(*inp):
                    x = m(*inp)
                    x.sum().backward()

                    grads = []
                    for name, param in m.named_parameters():
                        grad = param.grad
                        if param.grad is None:
                            grad = torch.zeros_like(param)
                        grads.append(grad)
                    return grads

                return f
            else:
                return m

        manual_graph_break = not use_ddp_wrapper
        mod = model_class(manual_graph_break=manual_graph_break).cuda()
        inp = [t.cuda() for t in mod.get_example_inputs()]
        expected_out = wrap_mod(mod)(*inp)

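        # A float64 copy of the model serves as the high-precision reference for `same`.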
        fp64_mod = copy.deepcopy(mod).to(torch.float64)
        fp64_inp = [t.to(torch.float64) for t in copy.deepcopy(inp)]
        fp64_out = wrap_mod(fp64_mod)(*fp64_inp)

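        # Wrapping with DDP is the non-manual way to get graph breaks (see the comment
        # at the top of this function).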
        if use_ddp_wrapper:
            from torch.nn.parallel import DistributedDataParallel as DDP

            ddp_wrapped_mod = DDP(mod)
            opt_mod = torch.compile(wrap_mod(ddp_wrapped_mod))
        else:
            opt_mod = torch.compile(wrap_mod(mod))
        actual_out = opt_mod(*inp)

        if is_train:
            self.assertTrue(same(expected_out, actual_out, fp64_ref=fp64_out))
        else:
            expected_sum = expected_out.sum()
            actual_sum = actual_out.sum()
            print(f"Expected sum {expected_sum}, actual sum {actual_sum}")
            self.assertTrue(same(expected_out, actual_out, fp64_ref=fp64_out))

    def verify_accuracy_for_infer(self, *args, **kwargs):
        self.verify_accuracy(*args, **kwargs, is_train=False)

    def verify_accuracy_for_train(self, *args, **kwargs):
        self.verify_accuracy(*args, **kwargs, is_train=True)

    def test_2conv_with_graph_break(self):
        """
        Make sure a graph break does not cause any accuracy issues.
        """
        self.verify_accuracy_for_infer(Model2Conv)

    def test_3conv_with_graph_break(self):
        class Model(nn.Module):
            def __init__(
                self, dim=512, patch_size=7, kernel_size=7, manual_graph_break=False
            ):
                super().__init__()
                self.seq = nn.Sequential(
                    nn.Conv2d(
                        3, dim, kernel_size=patch_size, stride=patch_size, bias=False
                    ),
                    nn.Conv2d(
                        dim, dim, kernel_size, groups=dim, padding="same", bias=False
                    ),
                )
                self.conv = nn.Conv2d(dim, dim, kernel_size=1, bias=False)
                self.manual_graph_break = manual_graph_break

            def forward(self, x):
                x = self.seq(x)
                if self.manual_graph_break:
                    torch._dynamo.graph_break()
                x = self.conv(x)
                return x

            def get_example_inputs(self):
                return (torch.randn(2, 3, 16, 16),)

        self.verify_accuracy_for_infer(Model)

    @torch.no_grad()
    def test_keep_output_layout_infer(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(
                    3, 128, kernel_size=3, padding=1, stride=1, bias=False
                )

            def forward(self, x):
                x = self.conv(x)
                return x

            def get_example_inputs(self):
                return (torch.randn(2, 3, 5, 5),)

        mod = Model().cuda()
        inp = [t.cuda() for t in mod.get_example_inputs()]
        out = mod(*inp)

        opt_mod = torch.compile(mod)
        opt_out = opt_mod(*inp)

        # We should be able to call view on the eager output.
        out.view(5, -1)

        # We should be able to call view on the output of the optimized module as well.
        # Note that if that output were channels last, the view op would fail.
        opt_out.view(5, -1)

    def test_keep_output_layout_with_freezing(self):
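        # The output-layout guarantee above should also hold when inductor freezing is
        # enabled.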
        with config.patch(
            {
                "freezing": True,
            }
        ):
            self.test_keep_output_layout_infer()

    def test_training_acc(self):
        self.verify_accuracy_for_train(Model2Conv)

    def test_mutate_view(self):
        """
        The GraphModule passed to GraphLowering's init method looks like:
        https://gist.github.com/shunting314/07228313fd017e2267101ff32edc6d64

        It shows that we call copy_ at the end to update the argument. This
        guarantees correctness.
        """

        @torch.compile
        def f(x):
            y = x.view(3, 2)
            y.mul_(2)

        x = torch.ones(2, 3).cuda()
        f(x)
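        # The in-place update made through the view must be visible on the base tensor.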
        self.assertTrue(torch.equal(x, torch.ones(2, 3).cuda() * 2))

    def test_mutate_base(self):
        """
        The GraphModule passed to GraphLowering's init method looks like:
        https://gist.github.com/shunting314/fd60fe11d1f844c6db76aba7b06811bc

        It shows that the output of the graph is the mul node, which contains
        the update we applied to the base tensor.
        """

        @torch.compile
        def f(x):
            y = x.view(3, 2)
            x.mul_(2)
            return y

        x = torch.ones(2, 3).cuda()
        y = f(x)
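        # Mutating the base must be reflected in the previously created view.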
        self.assertTrue(torch.equal(y, torch.ones(3, 2).cuda() * 2))

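    # TF32 is turned off so the fp32 conv results compare cleanly against the fp64
    # reference used by verify_accuracy.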
    @tf32_off()
    def test_mutate_base_for_conv_output(self):
        class Model(nn.Module):
            def __init__(self, manual_graph_break=False):
                super().__init__()
                self.conv = nn.Conv2d(3, 512, kernel_size=3, stride=2, bias=False)

            def forward(self, x):
                x = self.conv(x)
                y = x.view(-1)
                x.mul_(2)
                return y

            def get_example_inputs(self):
                return (torch.rand(2, 3, 16, 16),)

        self.verify_accuracy_for_infer(Model)

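    # Same as above, but here the in-place update goes through the view while the conv
    # output (the base tensor) is returned.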
    @tf32_off()
    def test_mutate_view_for_conv_output(self):
        class Model(nn.Module):
            def __init__(self, manual_graph_break=False):
                super().__init__()
                self.conv = nn.Conv2d(3, 512, kernel_size=3, stride=2, bias=False)

            def forward(self, x):
                x = self.conv(x)
                y = x.view(-1)
                y.mul_(2)
                return x

            def get_example_inputs(self):
                return (torch.rand(2, 3, 16, 16),)

        self.verify_accuracy_for_infer(Model)

    def test_dynamic_shape_specialization(self):
        """
        Previously, aot_autograd.py compared the strides of the FakeTensor with
        those of the real tensor. That caused dynamic dimensions of the FakeTensor
        to be specialized to static shapes. This test protects against that.
        """

        def f(a, b):
            x = a.sin()
            y = b.cos()
            z = x + y
            return z

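        # Compile with dynamic=True and run several sizes; the dynamic dimension should
        # not get specialized to a single static size.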
        for size in [4, 8, 16]:
            a = torch.randn(2, size, requires_grad=True).cuda()
            b = torch.randn(2, size).cuda()
            actual = torch.compile(f, dynamic=True)(a, b)
            self.assertTrue(torch.allclose(f(a, b), actual))

            # Trigger compilation of the backward graph
            actual.sum().backward()

    def test_nll_loss_backward(self):
        """
        Repro for issue https://github.com/pytorch/pytorch/issues/120759

        The CUDA implementation of aten.nll_loss2d_backward.default requires
        the self tensor (whose layout will be used to create grad_input)
        to be contiguous. Layout optimization may change the self tensor's layout
        and cause a failure. We fix that by adding layout constraints to the
        fallback of aten.nll_loss2d_backward.default.
        """

        class MyModel(torch.nn.Module):
            def __init__(self, input_dim, num_classes):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, num_classes, 3, 1, padding="same")
                self.out = torch.nn.Linear(input_dim * num_classes, num_classes)

            def forward(self, x: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
                x = self.conv(x)
                b, c, t, f = x.size()
                x = self.out(x.reshape(b, t, c * f))
                logits = x.reshape(x.size(0), x.size(2), x.size(1))
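                # (N, C, L) logits with (N, L) targets route cross_entropy through the
                # aten.nll_loss2d_backward kernel described in the docstring.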
                loss = torch.nn.functional.cross_entropy(logits, targets)
                return loss

        device = "cuda"
        batch_size = 48
        seq_len = 144
        input_dim = 39
        num_classes = 111

        model = MyModel(input_dim, num_classes)
        model.to(device)

        opt_model = torch.compile(model)

        x = torch.ones((batch_size, 1, seq_len, input_dim), device=device)
        targets = torch.randint(
            0, num_classes - 1, (batch_size, seq_len), device=device, dtype=torch.int64
        )

        # Run the compiled model and its backward, then compare against the eager model.
        loss = opt_model(x, targets)
        loss.backward()

        ref = model(x, targets)
        self.assertTrue(torch.allclose(ref, loss))


if __name__ == "__main__":
    if HAS_CUDA:
        run_tests()