# Owner(s): ["module: inductor"]
import copy
import functools
import os
import unittest
from typing import Tuple

import torch
from torch import nn, Tensor
from torch._dynamo.convert_frame import maybe_cprofile
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import rand_strided, reduce_to_scalar_loss
from torch._inductor import config, ir, metrics
from torch._inductor.fx_passes import pad_mm as pad_mm_pass
from torch._inductor.runtime.benchmarking import benchmarker
from torch._inductor.utils import ceildiv, run_and_get_code
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    requires_cuda,
    serialTest,
)
from torch.testing._internal.inductor_utils import HAS_CUDA


DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"
DO_ACC_TEST = os.environ.get("DO_ACC_TEST", "1") == "1"
WITH_STACK = os.environ.get("WITH_STACK") == "1"
USE_CUDA_GRAPHS = os.environ.get("USE_CUDA_GRAPHS", "1") == "1"

try:
    import transformers  # noqa: F401

    HAS_TRANSFORMER = True
except ImportError:
    HAS_TRANSFORMER = False


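# Note: capturable=True keeps Adam's step count on the GPU so the optimizer step
# can run under CUDA graph capture; foreach=True selects the multi-tensor
# implementation of the update.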
def get_optim(m):
    return torch.optim.Adam(m.parameters(), lr=0.01, capturable=True, foreach=True)


def gen_transformer_inputs(vocab_size, bs, seq_length):
    def geninp():
        return torch.randint(
            0, vocab_size, (bs, seq_length), dtype=torch.int64, requires_grad=False
        )

    input_dict = {"input_ids": geninp(), "labels": geninp()}
    return input_dict


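# A note on the vocab sizes used in the tests below: 30522 (Bert's default) and
# 30523 give rows whose byte size is not a multiple of Inductor's preferred
# alignment, while 30528 (the next multiple of 64, i.e. 128-byte-aligned rows
# for fp32) is aligned, which is presumably why the tests label these shapes
# "bad" and "good" respectively.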
class LinearAndSoftmax(nn.Module):
    """
    It's very common for a transformer model to do a matmul and then
    softmax/log_softmax at the end.

    This toy model captures that pattern so we can make sure we do
    proper padding.
    """

    def __init__(self, vocab_size=30523, bias=True):
        """
        The default vocab size for BertForMaskedLM is 30522.
        We run a few test cases with good and bad vocab_size values around
        Bert's default.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.linear = nn.Linear(768, vocab_size, bias=bias)
        self.ce = nn.CrossEntropyLoss()

    def forward(self, x, label):
        x = self.linear(x)
        return self.ce(x.view(-1, self.vocab_size), label.view(-1))

    def get_example_inputs(self, batch_size=16):
        return torch.randn(batch_size, 512, 768), torch.randint(
            0, self.vocab_size, (batch_size, 512)
        )


def forward_and_backward_pass(m, inputs):
    m(*inputs).sum().backward()


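# The config patches below are conveniences for the perf runs in this file:
# benchmark_kernel asks Inductor to emit a benchmark harness for each generated
# kernel, and triton.unique_kernel_names keeps kernel names descriptive so
# individual kernels are easy to pick out in profiler traces.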
@config.patch(
    {
        "benchmark_kernel": True,
        "triton.unique_kernel_names": True,
        "triton.cudagraphs": USE_CUDA_GRAPHS,
    }
)
@requires_cuda
class TestCaseBase(TestCase):
    @classmethod
    def setUpClass(cls):
        if HAS_CUDA:
            cls.prior_float32_matmul_precision = torch.get_float32_matmul_precision()
            cls.prior_default_device = torch.get_default_device()
            torch.set_float32_matmul_precision("high")
            torch.set_default_device("cuda")

    @classmethod
    def tearDownClass(cls):
        if HAS_CUDA:
            torch.set_float32_matmul_precision(cls.prior_float32_matmul_precision)
            torch.set_default_device(cls.prior_default_device)

            cls.prior_float32_matmul_precision = None
            cls.prior_default_device = None

    def check_close(self, ref, act, tol=1e-3):
        if type(ref).__name__ == "LongformerMaskedLMOutput":
            ref = ref.loss
            act = act.loss
        if type(ref).__name__ == "SequenceClassifierOutput":
            ref = ref.logits
            act = act.logits
        if isinstance(ref, dict) and "loss" in ref:
            ref = ref["loss"]
            act = act["loss"]
        self.assertTrue(
            torch.allclose(ref, act, atol=tol, rtol=tol), f"ref:\n{ref}\nact:\n{act}"
        )

    def common_numeric_check(self, f, *args, tol=1e-3, **kwargs):
        ref = f(*args, **kwargs)
        opt_f = torch.compile(f)
        act = opt_f(*args, **kwargs)
        self.check_close(ref, act, tol)

    def do_profiling(
        self,
        f_lhs,
        f_rhs,
        tag_lhs="With padding",
        tag_rhs="Without padding",
        args=(),
        kwargs=None,
    ):
        if kwargs is None:
            kwargs = {}
        torch.cuda.synchronize()
        with torch.profiler.profile(with_stack=WITH_STACK) as p:
            niter = 3
            for _ in range(niter):
                with torch.profiler.record_function(tag_lhs):
                    f_lhs(*args, **kwargs)

                with torch.profiler.record_function(tag_rhs):
                    f_rhs(*args, **kwargs)
            torch.cuda.synchronize()

        profile_path = "/tmp/chrome.json"
        p.export_chrome_trace(profile_path)
        print(f"Chrome trace is written to {profile_path}")


class PerfTestBetweenGoodAndBadShape(TestCaseBase):
    @unittest.skipIf(not DO_PERF_TEST, "Perf test not enabled")
    def test_nobias_LinearAndSoftmax_both_shapes(self):
        self.test_LinearAndSoftmax_both_shapes(bias=False)

    @unittest.skipIf(not DO_PERF_TEST, "Perf test not enabled")
    def test_LinearAndSoftmax_both_shapes(self, bias=True):
        """
        Compare the perf with good and bad shapes.
        """
        m_bad_shape = LinearAndSoftmax(vocab_size=30523, bias=bias)
        inputs_bad_shape = m_bad_shape.get_example_inputs()
        m_good_shape = LinearAndSoftmax(vocab_size=30528, bias=bias)
        inputs_good_shape = m_good_shape.get_example_inputs()

        m_bad_shape_opt = torch.compile(m_bad_shape)
        m_good_shape_opt = torch.compile(m_good_shape)

        latency_good_shape = benchmarker.benchmark_gpu(
            lambda: forward_and_backward_pass(m_good_shape_opt, inputs_good_shape)
        )
        latency_bad_shape = benchmarker.benchmark_gpu(
            lambda: forward_and_backward_pass(m_bad_shape_opt, inputs_bad_shape)
        )
        print(
            f"Latency for good shape vs. bad shape: {latency_good_shape:.3f}ms vs. {latency_bad_shape:.3f}ms"
        )

    @unittest.skipIf(not DO_PERF_TEST or not HAS_TRANSFORMER, "Perf test not enabled")
    def test_BertForMaskedLM(self, num_layers=1):
        """
        Compare the perf between padding a bad shape and using a good shape.
        """
        from transformers import BertForMaskedLM

        config_cls = BertForMaskedLM.config_class
        bs = 16
        seq_length = 512

        def create_model(vocab_size):
            config = config_cls()
            config.num_hidden_layers = num_layers
            config.vocab_size = vocab_size
            inputs = gen_transformer_inputs(config.vocab_size, bs, seq_length)
            model = BertForMaskedLM(config)

            optim = get_optim(model)

            def f(**inputs):
                optim.zero_grad(True)
                with torch.cuda.amp.autocast():
                    pred = model(**inputs)
                    loss = pred[0]
                loss.backward()
                optim.step()

            return torch.compile(f), inputs

        f_good_shape, inputs_good_shape = create_model(30528)
        f_bad_shape, inputs_bad_shape = create_model(30522)

        print("benchmark for good shape")
        latency_good_shape = benchmarker.benchmark_gpu(
            lambda: f_good_shape(**inputs_good_shape)
        )
        print("benchmark for bad shape")
        latency_bad_shape = benchmarker.benchmark_gpu(
            lambda: f_bad_shape(**inputs_bad_shape)
        )
        print(
            f"Latency with good and bad shape: {latency_good_shape:.3f} vs. {latency_bad_shape:.3f}"
        )

        self.do_profiling(
            lambda: f_good_shape(**inputs_good_shape),
            lambda: f_bad_shape(**inputs_bad_shape),
            tag_lhs="With good shape",
            tag_rhs="With bad shape",
        )


class PerfTestWithAndWithoutPadding(TestCaseBase):
    @maybe_cprofile
    def run_acc_and_perf_test(self, model, inputs, perf_inputs=None, tol=1e-3):
        """
        Run the accuracy test.

        Also compare the perf with and without comprehensive padding if
        DO_PERF_TEST is true.
        """
        if perf_inputs is None:
            perf_inputs = inputs

        def _process_inputs(x):
            """
            return args and kwargs
            """
            if isinstance(x, dict):
                return [], x

            if not isinstance(x, (tuple, list)):
                x = [x]

            return x, {}

        args, kwargs = _process_inputs(inputs)
        perf_args, perf_kwargs = _process_inputs(perf_inputs)

        if DO_ACC_TEST:
            model.eval()
            self.common_numeric_check(model, *args, **kwargs, tol=tol)
        else:
            print("Accuracy test skipped")

        model.train()

        if DO_PERF_TEST:
            print("Do performance test")

            def get_f(m, optim):
                def f(*args, **kwargs):
                    optim.zero_grad(True)
                    with torch.cuda.amp.autocast():
                        pred = m(*args, **kwargs)
                        loss = reduce_to_scalar_loss(pred)
                    loss.backward()
                    optim.step()

                return f

            latency_with_padding = None
            print("Benchmark with padding")
            with config.patch(comprehensive_padding=True):
                m_copy_with_padding = copy.deepcopy(model)
                optim_with_padding = get_optim(m_copy_with_padding)
                opt_f_with_padding = torch.compile(
                    get_f(m_copy_with_padding, optim_with_padding)
                )
                latency_with_padding = benchmarker.benchmark_gpu(
                    lambda: opt_f_with_padding(*perf_args, **perf_kwargs)
                )
            latency_without_padding = None
            print("Benchmark without padding")
            with config.patch(comprehensive_padding=False):
                m_copy_without_padding = copy.deepcopy(model)
                optim_without_padding = get_optim(m_copy_without_padding)
                opt_f_without_padding = torch.compile(
                    get_f(m_copy_without_padding, optim_without_padding)
                )
                latency_without_padding = benchmarker.benchmark_gpu(
                    lambda: opt_f_without_padding(*perf_args, **perf_kwargs)
                )
            print(
                f"Latency with and without padding: {latency_with_padding:.3f} vs. {latency_without_padding:.3f}"
            )

            # profiling
            self.do_profiling(
                opt_f_with_padding,
                opt_f_without_padding,
                args=perf_args,
                kwargs=perf_kwargs,
            )

    def test_nvidia_deeprecommender(self):
        """
        Compare the perf with and without comprehensive padding.
        """
        layer_sizes = [197951, 512, 512, 1024, 512, 512, 197951]
        x = torch.randn(4, layer_sizes[0])

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                mod_list = []
                for i in range(len(layer_sizes) - 1):
                    mod_list.append(nn.Linear(layer_sizes[i], layer_sizes[i + 1]))
                    mod_list.append(nn.SELU())

                    if i == 2:
                        mod_list.append(nn.Dropout(0.8))
                self.seq = nn.Sequential(*mod_list)

            def forward(self, x):
                return self.seq(x)

        m = Model()
        perf_inputs = torch.randn(256, layer_sizes[0])
        self.run_acc_and_perf_test(m, x, perf_inputs)

    @unittest.skipIf(not DO_PERF_TEST or not HAS_TRANSFORMER, "Perf test not enabled")
    def test_longformer(self, bs=4):
        from transformers import AutoConfig, AutoModelForMaskedLM

        config = AutoConfig.from_pretrained("allenai/longformer-base-4096")
        model = AutoModelForMaskedLM.from_config(config)

        vocab_size = model.config.vocab_size
        seq_length = 1024
        input_dict = gen_transformer_inputs(vocab_size, bs, seq_length)

        self.run_acc_and_perf_test(model, input_dict)

    @unittest.skipIf(not DO_PERF_TEST or not HAS_TRANSFORMER, "Perf test not enabled")
    def test_longformer_small_bs(self):
        """
        The model exists in both HF and TB. In TB it uses a smaller batch size.
        """
        self.test_longformer(bs=2)


@instantiate_parametrized_tests
class PaddingTest(TestCaseBase):
    @unittest.skipIf(not DO_PERF_TEST, "Perf test not enabled")
    def test_mm_padding_perf(self):
        def naive_mm(a, b):
            return a @ b

        def _compute_padding(s, align):
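            # Round s up to the next multiple of `align` and return how many
            # padding elements that requires (0 if s is already aligned).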
            return (s + align - 1) // align * align - s

        @torch.compile
        def pad_mm(a, b, align=16):
            """
            NOTE: this function only pads a single dimension, which is good
            enough for testing.
            """
            m_padding = _compute_padding(a.size(0), align)
            k_padding = _compute_padding(a.size(1), align)
            n_padding = _compute_padding(b.size(1), align)
            return pad_mm_pass.pad_mm(a, b, m_padding, k_padding, n_padding)

        for M, K, N, f in (
            (8192, 768, 30523, naive_mm),
            (8192, 768, 30523, pad_mm),
            (8192, 768, 30528, naive_mm),
            (30523, 8192, 768, naive_mm),
            (30528, 8192, 768, naive_mm),
        ):
            a = torch.randn(M, K)
            b = torch.randn(K, N)
            ms = benchmarker.benchmark_gpu(lambda: f(a, b))
            print(f"MxKxN {M}x{K}x{N} {f.__name__}: {ms:.3f}ms")

    @unittest.skipIf(not DO_PERF_TEST, "Perf test not enabled")
    def test_padmm(self):
        """
        Latency between original matmul and padded matmul: 2.717 vs. 2.356
        """
        mat1_pad = torch.randn(8192, 30522, dtype=torch.float16)
        mat2_pad = torch.randn(30522, 768, dtype=torch.float16)

        def f():
            return mat1_pad @ mat2_pad

        def pad_dim(x: Tensor, padded_length: int, dim: int) -> Tensor:
            pad = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :])
            return torch.cat([x, pad], dim=dim)

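        # Zero-padding the shared K dimension of both operands does not change
        # the matmul result; 30522 + 6 = 30528 makes each fp16 row a multiple
        # of 128 bytes, which is presumably why 6 elements of padding are used.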
        @torch.compile(fullgraph=True, options={"triton.cudagraphs": False})
        def g():
            mat1 = mat1_pad
            mat2 = mat2_pad
            mat1 = pad_dim(mat1, 6, 1)
            mat2 = pad_dim(mat2, 6, 0)
            return torch.ops.aten.mm(mat1, mat2)

        ori_time = benchmarker.benchmark_gpu(f)
        pad_time = benchmarker.benchmark_gpu(g)

        print(
            f"Latency between original matmul and padded matmul: {ori_time:.3f} vs. {pad_time:.3f}"
        )
        self.do_profiling(f, g, "No MM Padding", "With mm padding")

    @unittest.skipIf(not DO_PERF_TEST, "Perf test not enabled")
    def test_matmul(self):
        """
        Latency with good and bad shapes: 1.705 vs. 2.625
        """
        x_good_shape = torch.randn(8192, 30528, dtype=torch.float16)
        weight_good_shape = torch.randn(30528, 768, dtype=torch.float16)
        out_good_shape = torch.randn(8192, 768, dtype=torch.float16)

        # Using stride (30522, 1) does not make a difference here.
        x_bad_shape = rand_strided(
            (8192, 30522), (30528, 1), device="cuda", dtype=torch.float16
        )
        weight_bad_shape = torch.randn(30522, 768, dtype=torch.float16)
        out_bad_shape = torch.randn(8192, 768, dtype=torch.float16)

        def f(x, weight, out):
            torch.mm(x, weight, out=out)
            return out

        f1 = torch.compile(
            functools.partial(f, x_good_shape, weight_good_shape, out_good_shape)
        )
        f2 = torch.compile(
            functools.partial(f, x_bad_shape, weight_bad_shape, out_bad_shape)
        )
        latency_good_shape = benchmarker.benchmark_gpu(f1)
        latency_bad_shape = benchmarker.benchmark_gpu(f2)
        print(
            f"Latency with good and bad shapes: {latency_good_shape:.3f} vs. {latency_bad_shape:.3f}"
        )
        self.do_profiling(f1, f2)

    @serialTest()
    def test_nobias_LinearAndSoftmax_codegen(self):
        self.test_LinearAndSoftmax_codegen(bias=False)

    def test_LinearAndSoftmax_codegen(self, bias=True):
        m_bad_shape = LinearAndSoftmax(vocab_size=30523, bias=bias)
        inputs_bad_shape = m_bad_shape.get_example_inputs()
        m_bad_shape_opt = torch.compile(copy.deepcopy(m_bad_shape))

        _, wrapper_codes = run_and_get_code(
            forward_and_backward_pass, m_bad_shape_opt, inputs_bad_shape
        )
        forward_and_backward_pass(m_bad_shape, inputs_bad_shape)
        self.assertEqual(
            m_bad_shape.linear.weight.grad, m_bad_shape_opt.linear.weight.grad
        )
        self.assertTrue(len(wrapper_codes) == 2)  # one for forward and one for backward
        forward_wrapper = wrapper_codes[0]

        # make sure the load for softmax is aligned
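        # (30523 rounded up to the padded row stride is 30528; 30528 * 4 bytes
        # is a multiple of 128, so the generated load below is aligned)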
        self.assertTrue(
            "tl.load(in_ptr0 + (r1 + (30528*x0))" in forward_wrapper,
            f"forward_wrapper: {forward_wrapper}",
        )

        if DO_PERF_TEST:
            latency = benchmarker.benchmark_gpu(
                lambda: forward_and_backward_pass(m_bad_shape_opt, inputs_bad_shape)
            )
            print(f"latency: {latency:.3f}ms")

    @config.patch(pattern_matcher=False)
    def test_attention(self):
        batch_size, seq_len, num_heads, hidden_size = 1, 4, 1, 16
        inv_scale = (num_heads / hidden_size) ** 0.5
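        # inv_scale equals 1 / sqrt(head_dim) since head_dim = hidden_size / num_heads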

        class Attention(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.query = nn.Linear(hidden_size, hidden_size)
                self.key = nn.Linear(hidden_size, hidden_size)
                self.value = nn.Linear(hidden_size, hidden_size)

            @staticmethod
            def reshape(x):
                return x.view(batch_size, seq_len, num_heads, -1).permute(0, 2, 1, 3)

            @staticmethod
            def cancel_reshape(x):
                return x.permute(0, 2, 1, 3).view(batch_size, seq_len, hidden_size)

            def forward(self, x):
                query, key, value = self.query(x), self.key(x), self.value(x)
                weights = (
                    torch.matmul(
                        self.reshape(query), self.reshape(key).permute(0, 1, 3, 2)
                    )
                    * inv_scale
                ).softmax(dim=-1)
                return self.cancel_reshape(torch.matmul(weights, self.reshape(value)))

        attn = Attention()
        x = torch.randn(batch_size, seq_len, hidden_size)

        self.common_numeric_check(attn, x)

    def test_view(self):
        def f(x):
            return x.view(3, 3, 3)

        x = torch.randn(3, 9)
        self.common_numeric_check(f, x)

    def test_pad_strides(self):
        """
        Note that dim0's stride is also padded even though its previous value
        is already a multiple of 16. The reason is that we padded dim1's stride,
        so we have to correspondingly increase the stride for dim0.
        """
        sizes = [2, 16, 2047]
        in_strides = [2047 * 16, 2047, 1]
        out_strides = list(ir.Layout._pad_strides(in_strides, sizes, torch.float32))
        expected_strides = [2048 * 16, 2048, 1]
        self.assertEqual(
            expected_strides, out_strides, f"{expected_strides} vs. {out_strides}"
        )

    def test_pad_strides_skip(self):
        """
        The padding is skipped to avoid too much memory overhead.
        """
        sizes = [2, 32, 127]
        in_strides = [4064, 127, 1]
        out_strides = list(ir.Layout._pad_strides(in_strides, sizes, torch.float32))
        expected_strides = [4064, 127, 1]
        self.assertEqual(
            expected_strides, out_strides, f"{expected_strides} vs. {out_strides}"
        )

    def test_pad_3d_tensor(self):
        """
        This test case is constructed around the fact that we don't pad the
        strides of placeholders or user-visible outputs.

        Add a matmul at the beginning and at the end so we can pad strides for
        intermediate tensors.
        """

        def f(x, y):
            x = torch.matmul(x, y)
            x = x + 1
            return torch.matmul(x, y)

        x = torch.randn(2, 16, 2047)
        y = torch.randn(2047, 2047)
        self.common_numeric_check(f, x, y, tol=1e-2)
        self.assertTrue(metrics.num_comprehensive_padding > 0)

    def test_conv(self):
        """
        Padding the input for convolution may cause an extra copy kernel to be called.
        Check this example trace: https://gist.github.com/shunting314/ce45398f7d51a63ce05fc8d411faddb3
        """
        x_shape = (1, 128, 640, 959)
        x1 = torch.randn(*x_shape)

        padded_stride = ir.Layout._pad_strides(x1.stride(), x1.shape, torch.float32)
        x2 = rand_strided(x_shape, padded_stride, device="cuda")
        x2.copy_(x1)

        weight = torch.randn(64, 128, 3, 3)

        def fun(x, weight):
            return torch.convolution(
                x,
                weight,
                stride=(1, 1),
                padding=(1, 1),
                dilation=(1, 1),
                transposed=False,
                output_padding=(0, 0),
                groups=1,
                bias=None,
            )

        ref = fun(x1, weight)
        act = fun(x2, weight)
        self.check_close(ref, act)
        if DO_PERF_TEST:
            latency_with_padding = benchmarker.benchmark_gpu(lambda: fun(x2, weight))
            latency_without_padding = benchmarker.benchmark_gpu(lambda: fun(x1, weight))
            print(
                f"Latency with and without padding: {latency_with_padding:.3f} vs. {latency_without_padding:.3f}"
            )

            self.do_profiling(lambda: fun(x2, weight), lambda: fun(x1, weight))

    @unittest.skipIf(not DO_PERF_TEST, "Perf test not enabled")
    def test_cat(self):
        """
        Compare the perf between aten cat and compiled cat.

        Latency between eager and compiled: 1.596 vs. 0.601

        Eager cat can be 2.66x slower than the Inductor kernel.
        """
        x = torch.randn(8192, 30522, dtype=torch.float16)

        def f(x):
            pad = x.new_zeros(x.size(0), 6)
            return torch.cat([x, pad], dim=1)

        # Disable cudagraphs since cudagraphs need to copy the input, which
        # distorts the latency a lot (it doubles the latency here for the
        # compiled version).
        with config.patch("triton.cudagraphs", False):
            opt_f = torch.compile(f)
            opt_f(x)
        eager_time = benchmarker.benchmark_gpu(lambda: f(x))
        opt_time = benchmarker.benchmark_gpu(lambda: opt_f(x))
        print(
            f"Latency between eager and compiled: {eager_time:.3f} vs. {opt_time:.3f}"
        )
        self.do_profiling(lambda: f(x), lambda: opt_f(x), "Eager Cat", "Compiled Cat")

    def test_pad_channels_last(self):
        t = torch.randn(2, 3, 5, 1025)
        in_strides = t.stride()
        out_strides = ir.Layout._pad_strides(in_strides, t.shape, torch.float32)
        self.assertTrue(in_strides != out_strides)

        t = t.to(memory_format=torch.channels_last)
        in_strides = t.stride()
        out_strides = ir.Layout._pad_strides(in_strides, t.shape, torch.float32)
        self.assertTrue(in_strides == out_strides)

    @parametrize("alignment_bytes", (32, 128))
    @parametrize("shape", [(21, 19), (3, 5, 71)])
    @parametrize("dtype", (torch.float16, torch.float32))
    def test_pad_outputs(
        self, dtype: torch.dtype, shape: Tuple[int], alignment_bytes: int
    ):
        """
        Tests padding output tensors to a specific alignment.
        This is enabled by a config flag.
        """
        func = torch.add
        inputs = tuple(torch.randn(*shape, dtype=dtype) for input_idx in range(2))

        # Compile and run
        with config.patch(
            {
                "comprehensive_padding": True,
                "padding_alignment_bytes": alignment_bytes,
                "padding_stride_threshold": 0,
                "pad_outputs": True,
            }
        ):
            compiled_func = torch.compile(func)
            compiled_out = compiled_func(*inputs)

        # Check numerics
        eager_out = func(*inputs)
        self.check_close(eager_out, compiled_out)

        # Compute the expected padding
        element_size = torch.tensor([], dtype=dtype).element_size()
        self.assertGreater(alignment_bytes, element_size)
        self.assertEqual(alignment_bytes % element_size, 0)
        alignment_elements = alignment_bytes // element_size
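        # Worked example: for fp32 with alignment_bytes=32, alignment_elements is
        # 8, so a (21, 19) input's innermost slice of 19 elements is padded to
        # ceildiv(19, 8) * 8 = 24 and the expected strides become (24, 1).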
        contiguous_stride = inputs[0].stride()
        expected_stride = [1]
        for dim in reversed(shape[1:]):
            slice_size = dim * expected_stride[0]
            new_stride = alignment_elements * ceildiv(slice_size, alignment_elements)
            expected_stride.insert(0, new_stride)
        expected_stride = tuple(expected_stride)
        self.assertNotEqual(expected_stride, contiguous_stride)

        # Check strides
        self.assertFalse(compiled_out.is_contiguous())
        self.assertEqual(compiled_out.stride(), expected_stride)


if __name__ == "__main__":
    if HAS_CUDA:
        run_tests()
