xref: /aosp_15_r20/external/pytorch/test/inductor/test_cpu_repro.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: cpu inductor"]
2import contextlib
3import copy
4import functools
5import itertools
6import math
7import platform
8import sys
9import unittest
10from typing import Callable
11from unittest.mock import patch
12
13import numpy as np
14import sympy
15
16import torch
17from torch import nn
18from torch._C import FileCheck
19from torch._dynamo.testing import rand_strided
20from torch._dynamo.utils import same
21from torch._inductor import codecache, config, metrics
22from torch._inductor.codegen.common import OptimizationContext
23from torch._inductor.codegen.cpp import (
24    CppOverrides,
25    CppVecKernelChecker,
26    CppVecOverrides,
27)
28from torch._inductor.compile_fx import (
29    compile_fx,
30    compile_fx_inner,
31    complex_memory_overlap,
32)
33from torch._inductor.graph import GraphLowering
34from torch._inductor.ir import InterpreterShim
35from torch._inductor.utils import timed
36from torch._inductor.virtualized import V
37from torch.fx.experimental.proxy_tensor import make_fx
38from torch.nn import functional as F
39from torch.testing._internal.common_utils import (
40    instantiate_parametrized_tests,
41    IS_MACOS,
42    parametrize,
43    slowTest,
44)
45from torch.utils._python_dispatch import TorchDispatchMode
46
47try:
48    try:
49        from . import test_torchinductor
50    except ImportError:
51        import test_torchinductor
52except unittest.SkipTest:
53    if __name__ == "__main__":
54        sys.exit(0)
55    raise
56
57
58vec_dtypes = test_torchinductor.vec_dtypes
59_lowp_fp_dtypes = (
60    torch.bfloat16,
61    torch.float16,
62)
63run_and_get_cpp_code = test_torchinductor.run_and_get_cpp_code
64TestCase = test_torchinductor.TestCase
65aten = torch.ops.aten
66check_model = test_torchinductor.check_model
67
68requires_vectorization = unittest.skipUnless(
69    codecache.valid_vec_isa_list(), "Does not support vectorization"
70)
71
72
73def check_metrics_vec_kernel_count(num_expected_vec_kernels):
74    if codecache.valid_vec_isa_list():
75        assert metrics.generated_cpp_vec_kernel_count == num_expected_vec_kernels
76
77
78@contextlib.contextmanager
79def set_num_threads(num_threads):
80    orig_num_threads = torch.get_num_threads()
81    torch.set_num_threads(num_threads)
82    yield
83    torch.set_num_threads(orig_num_threads)
84
85
86class LstmModule(torch.nn.Module):
87    def __init__(
88        self,
89        input_size,
90        hidden_size,
91        num_layers,
92        bias=True,
93        bidirectional=False,
94        batch_first=False,
95    ):
96        super().__init__()
97        self.lstm = torch.nn.LSTM(
98            input_size=input_size,
99            hidden_size=hidden_size,
100            num_layers=num_layers,
101            bias=bias,
102            bidirectional=bidirectional,
103            batch_first=batch_first,
104        )
105
106    def forward(self, x, h=None):
107        x, h = self.lstm(x, h)
108        return x, h
109
110
111@instantiate_parametrized_tests
112class CPUReproTests(TestCase):
113    common = check_model
114
115    def test_conv_stride_constraints(self):
116        for fmt in [torch.contiguous_format, torch.channels_last]:
117            # TorchDispatch doesn't work in our cuda invocation for some reason
118            m = torch.nn.Conv2d(5, 6, [3, 3])
119
120            def fn(inp, weight):
121                return (
122                    F.conv2d(
123                        inp, weight, None, m.stride, m.padding, m.dilation, m.groups
124                    ),
125                )
126
127            inp = torch.randn([2, 5, 16, 16])
128            inps = [inp, m.weight.to(memory_format=fmt)]
129            fn_fx = make_fx(fn)(*inps)
130            fn_compiled = compile_fx_inner(fn_fx, inps)
131            test_self = self
132            conv_seen = False
133
134            class RecordFunctions(TorchDispatchMode):
135                def __torch_dispatch__(self, func, types, args=(), kwargs=None):
136                    kwargs = kwargs if kwargs else {}
137                    if func == torch.ops.aten.convolution.default:
138                        # For CPU and mkldnn enable, we always using channles last
139                        nonlocal fmt
140                        if (
141                            torch.backends.mkldnn.enabled
142                            and torch.backends.mkldnn.is_available()
143                        ):
144                            fmt = torch.channels_last
145                        test_self.assertTrue(args[0].is_contiguous(memory_format=fmt))
146                        test_self.assertTrue(args[1].is_contiguous(memory_format=fmt))
147                        nonlocal conv_seen
148                        conv_seen = True
149
150                    return func(*args, **kwargs)
151
152            with RecordFunctions():
153                out = fn_compiled(inps)
154
155            self.assertTrue(conv_seen)
156
157    @patch("torch.cuda.is_available", lambda: False)
158    def test_conv2d_bn_mixed_dtype(self):
159        class Model(torch.nn.Module):
160            def __init__(self):
161                super().__init__()
162                self.conv = torch.nn.Conv2d(
163                    3,
164                    16,
165                    kernel_size=3,
166                    stride=1,
167                    padding=1,
168                    bias=False,
169                    dtype=torch.bfloat16,
170                )
171                self.bn = torch.nn.BatchNorm2d(
172                    16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
173                )
174
175            def forward(self, x):
176                x = self.conv(x)
177                x = self.bn(x)
178                return x
179
180        v = torch.randn(1, 3, 64, 64, dtype=torch.bfloat16)
181        mod = Model().eval()
182        with torch.no_grad():
183            self.common(
184                mod,
185                (v,),
186            )
187
188    @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
189    @patch("torch.cuda.is_available", lambda: False)
190    def test_conv2d_packed(self):
191        options = itertools.product([[3, 56, 56]], [True, False], [0, (0,)])
192        for x_shape, mode_train, padding in options:
193            mod = torch.nn.Sequential(
194                torch.nn.Conv2d(3, 64, 3, 3, padding=padding)
195            ).train(mode=mode_train)
196            v = torch.randn(x_shape, dtype=torch.float32)
197
198            with torch.no_grad():
199                self.common(
200                    mod,
201                    (v,),
202                )
203
204    @patch("torch.cuda.is_available", lambda: False)
205    def test_conv2d_autocast(self):
206        v = torch.randn(1, 3, 28, 18, dtype=torch.float32)
207        mod = torch.nn.Sequential(torch.nn.Conv2d(3, 64, 3, 3)).eval()
208        with torch.no_grad(), torch.cpu.amp.autocast():
209            self.common(
210                mod,
211                (v,),
212            )
213
214    @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
215    @patch("torch.cuda.is_available", lambda: False)
216    def test_unsupported_conv_transpose(self):
217        class Model(torch.nn.Module):
218            def __init__(self):
219                super().__init__()
220                self.conv_transpose = torch.nn.ConvTranspose2d(
221                    3, 6, 3, stride=1, padding=1, output_padding=1
222                )
223
224            def forward(self, input_tensor):
225                x = self.conv_transpose(input_tensor)
226                output = torch.tanh(x)
227                return output
228
229        input = torch.randn(1, 3, 28, 28)
230        m = Model().eval()
231
232        with torch.no_grad():
233            compiled_m = torch.compile(m)
234            with self.assertRaisesRegex(
235                RuntimeError,
236                "output padding must be smaller than either stride or dilation",
237            ):
238                compiled_m(input)
239
240    @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
241    @patch("torch.cuda.is_available", lambda: False)
242    def test_conv_used_from_multiple_places(self):
243        class M(torch.nn.Module):
244            def __init__(self, conv_in_channel, conv_out_channel) -> None:
245                super().__init__()
246                self.conv = torch.nn.Conv2d(conv_in_channel, conv_out_channel, (3, 3))
247
248            def forward(self, x):
249                res = self.conv(x)
250                res = F.relu(res)
251                res = self.conv(res)
252                return res
253
254        with torch.no_grad():
255            mod = M(3, 3).eval()
256            x = torch.randn(1, 3, 224, 224)
257            self.common(
258                mod,
259                (x,),
260            )
261
262    @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
263    @patch("torch.cuda.is_available", lambda: False)
264    def test_linear_used_from_multiple_places(self):
265        class M(torch.nn.Module):
266            def __init__(self, in_channel, out_channel) -> None:
267                super().__init__()
268                self.linear = torch.nn.Linear(in_channel, out_channel)
269
270            def forward(self, x):
271                res = self.linear(x)
272                res = F.relu(res)
273                res = self.linear(res)
274                return res
275
276        dtypes = []
277        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
278            dtypes.append(torch.bfloat16)
279        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
280            dtypes.append(torch.float16)
281        for dtype in dtypes:
282            with torch.no_grad():
283                m = M(224, 224).to(dtype).eval()
284                m_opt = torch.compile(m)
285                x = torch.randn(224, 224, dtype=dtype)
286                m_opt(x)
287                self.assertEqual(m(x), m_opt(x))
288
289    @config.patch(implicit_fallbacks=True)
290    def test_multihead_attention_cpu(self):
291        def fn(
292            q,
293            k,
294            v,
295            embed_dim,
296            num_heads,
297            qkv_weight,
298            qkv_bias,
299            proj_weight,
300            proj_bias,
301            mask,
302            need_weights,
303        ):
304            return torch._native_multi_head_attention(
305                q,
306                k,
307                v,
308                embed_dim,
309                num_heads,
310                qkv_weight,
311                qkv_bias,
312                proj_weight,
313                proj_bias,
314                mask,
315                need_weights,
316            )
317
318        B = 1
319        T = 3
320        embed_dim = 6
321        num_heads = 2
322        q = torch.randn([B, T, embed_dim])
323        k = torch.randn([B, T, embed_dim])
324        v = torch.randn([B, T, embed_dim])
325        qkv_weight = torch.randn([3 * embed_dim, embed_dim])
326        qkv_bias = torch.randn([3 * embed_dim])
327        proj_weight = torch.randn([3 * embed_dim, embed_dim])
328        proj_bias = torch.randn([3 * embed_dim])
329        mask = None
330        need_weights = False
331
332        inps = [
333            q,
334            k,
335            v,
336            embed_dim,
337            num_heads,
338            qkv_weight,
339            qkv_bias,
340            proj_weight,
341            proj_bias,
342            mask,
343            need_weights,
344        ]
345        self.common(fn, inps)
346
347    @config.patch(freezing=True)
348    def test_module_buffer_mutation(self):
349        class Model(torch.nn.Module):
350            def __init__(self):
351                super().__init__()
352                self.register_buffer("foo", torch.rand((3, 10)))
353
354            def forward(self, x):
355                lx = [x, x.clone(), x.clone()]
356                y = []
357                for i in range(3):
358                    y.append(lx[i] + self.foo[i])
359                return torch.cat(y, 1)
360
361        with torch.no_grad():
362            example_inputs = (torch.rand(1, 10),)
363            self.common(Model(), example_inputs)
364
365    @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
366    @patch("torch.cuda.is_available", lambda: False)
367    def test_linear_packed(self):
368        dtypes = []
369        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
370            dtypes.append(torch.bfloat16)
371        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
372            dtypes.append(torch.float16)
373        options = itertools.product(
374            [[2, 3, 10], [2, 10], [10], [2, 0]], [3, 0], [True, False], dtypes
375        )
376        for input_shape, out_dim, bias, dtype in options:
377            mod = torch.nn.Sequential(
378                torch.nn.Linear(input_shape[-1], out_dim, bias=bias)
379            ).eval()
380
381            v = torch.randn(input_shape)
382            with torch.no_grad():
383                self.common(
384                    mod.to(dtype),
385                    (v.to(dtype),),
386                )
387
388    @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
389    @patch("torch.cuda.is_available", lambda: False)
390    def test_conv_transpose2d_packed_cpu(self):
391        options = itertools.product([[1, 3, 28, 28], [3, 28, 28]], [0, (0,)])
392        for x_shape, padding in options:
393            mod = torch.nn.Sequential(
394                torch.nn.ConvTranspose2d(3, 64, 3, 3, padding=padding)
395            ).eval()
396            v = torch.randn(x_shape, dtype=torch.float32)
397            with torch.no_grad():
398                self.common(
399                    mod,
400                    (v,),
401                )
402
403    @config.patch(freezing=True)
404    @unittest.skipIf(not torch._C._has_mkldnn, "MKLDNN is not enabled")
405    @torch._dynamo.config.patch(dynamic_shapes=True)
406    @torch._dynamo.config.patch(assume_static_by_default=False)
407    def test_conv_in_channel_1_dynamic_shapes(self):
408        class M(torch.nn.Module):
409            def __init__(self, in_channel, out_channel) -> None:
410                super().__init__()
411                self.conv = torch.nn.Conv2d(in_channel, out_channel, 3)
412
413            def forward(self, x):
414                res = self.conv(x)
415                res = F.relu(res)
416                return res
417
418        # test the case where the channels dim of the input is 1
419        # Reproducer from the maml_omniglot model in Torchbench
420        in_channel = 1
421        out_channel = 3
422        amp_enabled_configs = [False]
423        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
424            # When amp is enabled here, the input to Conv is a FlexibleLayout.
425            # While it's disabled, the input is a FixedLayout.
426            amp_enabled_configs.append(True)
427        for amp_enabled in amp_enabled_configs:
428            mod = M(in_channel, out_channel).eval()
429            v = torch.randn(5, in_channel, 15, 15)
430            with torch.no_grad(), torch.cpu.amp.autocast(enabled=amp_enabled):
431                self.common(
432                    mod,
433                    (v,),
434                )
435
436    @unittest.skipIf(not torch._C._has_mkldnn, "MKLDNN is not enabled")
437    @patch("torch.cuda.is_available", lambda: False)
438    @torch._dynamo.config.patch(dynamic_shapes=True)
439    @torch._dynamo.config.patch(assume_static_by_default=False)
440    @torch._dynamo.config.patch(allow_rnn=True)
441    @config.patch(freezing=True)
442    def _test_lstm_packed(self, params_dict, change_input_sizes=False):
443        from torch._dynamo.utils import counters
444
445        for (
446            unbatched,
447            input_size,
448            hidden_size,
449            num_layers,
450            bidirectional,
451            bias,
452            empty_state,
453            batch_first,
454            batch_size,
455            seq_len,
456        ) in itertools.product(*list(params_dict.values())):
457            dtypes = [torch.float]
458            if torch.ops.mkldnn._is_mkldnn_bf16_supported():
459                dtypes.append(torch.bfloat16)
460            if torch.ops.mkldnn._is_mkldnn_fp16_supported():
461                dtypes.append(torch.float16)
462            for dtype in dtypes:
463                counters.clear()
464                num_directions = 2 if bidirectional else 1
465
466                seq_len_var = seq_len + 3
467                if unbatched:
468                    v = torch.randn(seq_len, input_size)
469                    v_var = torch.randn(seq_len_var, input_size)
470                    h = torch.randn(num_layers * num_directions, hidden_size)
471                    c = torch.randn(num_layers * num_directions, hidden_size)
472                else:
473                    if batch_first:
474                        v = torch.randn(batch_size, seq_len, input_size)
475                        v_var = torch.randn(batch_size, seq_len_var, input_size)
476                    else:
477                        v = torch.randn(seq_len, batch_size, input_size)
478                        v_var = torch.randn(seq_len_var, batch_size, input_size)
479                    h = torch.randn(
480                        num_layers * num_directions, batch_size, hidden_size
481                    )
482                    c = torch.randn(
483                        num_layers * num_directions, batch_size, hidden_size
484                    )
485
486                mod = LstmModule(
487                    input_size,
488                    hidden_size,
489                    num_layers,
490                    bias,
491                    bidirectional,
492                    batch_first,
493                ).eval()
494                maybe_autocast = (
495                    torch.cpu.amp.autocast()
496                    if dtype == torch.bfloat16
497                    else contextlib.nullcontext()
498                )
499
500                with torch.no_grad(), maybe_autocast:
501                    inps = [v]
502                    if not empty_state:
503                        inps.append((h, c))
504
505                    fn_opt = torch._dynamo.optimize("inductor")(mod)
506                    _, code = run_and_get_cpp_code(fn_opt, *inps)
507
508                    # Check that _flat_weights are not functional_tensor, otherwise
509                    # deepcopy will fail during recompilation.
510                    fn_opt_copy = copy.deepcopy(fn_opt)
511                    _flat_weights = fn_opt_copy.lstm._flat_weights
512                    for _flat_weight in _flat_weights:
513                        self.assertFalse(torch._is_functional_tensor(_flat_weight))
514
515                    self.assertTrue("aten.mkldnn_rnn_layer" in code)
516                    self.assertEqual(fn_opt(*inps), mod(*inps))
517                    self.assertEqual(
518                        counters["inductor"]["pattern_matcher_count"],
519                        num_layers * num_directions
520                        + 2,  # num of mkldnn_rnn_layer call + 2 view call on the concatenated hy, cy.
521                    )
522
523                    # Change input sizes
524                    if change_input_sizes:
525                        inps_var = [v_var]
526                        self.assertEqual(fn_opt(*inps_var), mod(*inps_var))
527
528    @slowTest
529    def test_lstm_packed(self):
530        params_dict = {
531            "unbatched": [True, False],
532            "input_size": [1, 2],
533            "hidden_size": [2],
534            "num_layers": [1, 2],
535            "bidirectional": [False, True],
536            "bias": [False, True],
537            "empty_state": [False, True],
538            "batch_first": [True, False],
539            "batch_size": [1, 2],
540            "seq_len": [1, 2],
541        }
542        self._test_lstm_packed(params_dict)
543
544    def test_lstm_packed_change_input_sizes_cpu(self):
545        params_dict = {
546            "unbatched": [False],
547            "input_size": [2],
548            "hidden_size": [5],
549            "num_layers": [3],
550            "bidirectional": [True],
551            "bias": [True],
552            "empty_state": [False],
553            "batch_first": [False],
554            "batch_size": [2],
555            "seq_len": [3],
556        }
557        self._test_lstm_packed(params_dict, change_input_sizes=True)
558
559    @torch._dynamo.config.patch(dynamic_shapes=True)
560    @torch._dynamo.config.patch(assume_static_by_default=False)
561    @torch._dynamo.config.patch(allow_rnn=True)
562    def test_pack_padded_sequence_lstm(self):
563        embedding_dim = 12
564        hidden_dim = 10
565        batch_size = 24
566        num_layers = 1
567        bidirectional = True
568        num_direc = 2
569        max_lens = 96
570
571        sent = torch.randn(batch_size, max_lens, embedding_dim)
572        hid_0 = torch.rand(num_layers * num_direc, batch_size, hidden_dim)
573        hid_1 = torch.randn(num_layers * num_direc, batch_size, hidden_dim)
574
575        sent_lens = torch.Tensor(
576            [1, 2, 3, 4, 5, 1, 3, 2, 96, 5, 3, 1, 1, 2, 1, 2, 3, 6, 1, 2, 4, 6, 2, 1]
577        )
578
579        assert sent_lens.shape[0] == batch_size
580        assert sent_lens.max().item() == max_lens
581
582        hidden_0 = hid_0.clone().requires_grad_(False)
583        hidden_1 = hid_1.clone().requires_grad_(False)
584        embeds = torch.nn.utils.rnn.pack_padded_sequence(
585            sent, sent_lens, batch_first=True, enforce_sorted=False
586        )
587
588        mod = LstmModule(
589            embedding_dim,
590            hidden_dim,
591            num_layers=num_layers,
592            bias=True,
593            bidirectional=bidirectional,
594            batch_first=True,
595        ).eval()
596
597        with torch.no_grad():
598            inps = [embeds, (hidden_0, hidden_1)]
599            fn_opt = torch._dynamo.optimize("inductor")(mod)
600            _, code = run_and_get_cpp_code(fn_opt, *inps)
601            # This case is unsupported
602            self.assertFalse("torch.ops.mkldnn._lstm" in code)
603            self.assertEqual(fn_opt(*inps), mod(*inps))
604
605    @patch("torch.cuda.is_available", lambda: False)
606    def test_conv_transpose2d_has_output_size_input(self):
607        # https://github.com/pytorch/pytorch/issues/100344.
608        class M(torch.nn.Module):
609            def __init__(self) -> None:
610                super().__init__()
611                self.conv_transpose = torch.nn.ConvTranspose2d(
612                    in_channels=3, out_channels=1, kernel_size=3, stride=1, padding=1
613                )
614
615            def forward(self, x):
616                return self.conv_transpose(x, output_size=(10, 10))
617
618        mod = M().eval()
619        v = torch.randn(1, 3, 10, 10, dtype=torch.float32)
620        with torch.no_grad():
621            self.common(
622                mod,
623                (v,),
624            )
625
626    def test_pad_with_nan_value(self):
627        # https://github.com/pytorch/pytorch/issues/100988.
628        class Model(torch.nn.Module):
629            def forward(self, x):
630                x = F.pad(x, (1, 1, 1, 1), value=float("nan"))
631                return x
632
633        mod = Model().eval()
634        v = torch.randn(1, 3, 10, 10, dtype=torch.float32)
635        with torch.no_grad():
636            self.common(
637                mod,
638                (v,),
639            )
640
641    def test_masked_fill_with_inf_or_nan_value(self):
642        def fn(value, mask):
643            y1 = torch.masked_fill(value, mask, float("inf"))
644            y2 = torch.masked_fill(value, mask, float("-inf"))
645            y3 = torch.masked_fill(value, mask, float("nan"))
646            return y1, y2, y3
647
648        value = torch.randn((2, 17))
649        mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8).to(torch.bool)
650        with torch.no_grad():
651            self.common(
652                fn,
653                (value, mask),
654            )
655
656    def test_relu_with_inf_value(self):
657        # https://github.com/pytorch/pytorch/issues/117544.
658
659        def fn(out):
660            out = torch.sinh(input=out)
661            out = torch.relu(input=out)
662            return out
663
664        x = torch.Tensor([-572373.5000, 755109.1250, 330995.5625])
665        with torch.no_grad():
666            self.common(
667                fn,
668                (x,),
669            )
670
671    def test_acosh_with_negative_large_input(self):
672        # https://github.com/pytorch/pytorch/issues/118267.
673
674        def fn(input):
675            out = torch.acosh(input)
676            return out
677
678        x = torch.Tensor(
679            [
680                [
681                    -8493.9854,
682                    431654.1250,
683                    71741.5859,
684                    608234.5000,
685                    -103814.7500,
686                    -699397.0000,
687                    -910685.8125,
688                    -832737.1875,
689                    875343.5000,
690                ]
691            ]
692        ).repeat(3, 9)
693
694        for dtype in [torch.float32, torch.bfloat16, torch.double]:
695            with torch.no_grad():
696                torch._dynamo.reset()
697                metrics.reset()
698                _x = x.to(dtype)
699                self.common(
700                    fn,
701                    (_x,),
702                )
703
704    @config.patch(implicit_fallbacks=True)
705    def test_repeat_interleave(self):
706        def fn(y):
707            return torch.repeat_interleave(y, 2, output_size=8)
708
709        a = torch.tensor([[1, 2], [3, 4]])
710        self.common(
711            fn,
712            (a,),
713        )
714
715    def test_inplace_squeeze_needed(self):
716        mod = torch.nn.Sequential(
717            torch.nn.Linear(10, 10),
718            torch.nn.LayerNorm(10),
719            torch.nn.ReLU(),
720        ).eval()
721
722        def fn(x):
723            return mod(x)
724
725        v = torch.randn(10)
726        # TODO: OMP parallel reduction order is not deterministic.
727        # Hence, the accurarcy might vary up and down. For short term,
728        # we increase the tolerance and will fix it later by using
729        # aten parallel.
730        self.common(fn, (v,), atol=5e-1, rtol=5e-1)
731
732    def test_cat_mul(self):
733        # https://github.com/pytorch/pytorch/issues/93365
734        def fn(p0, p1):
735            y1 = torch.cat([p0, p1], dim=0)
736            y2 = torch.mul(y1, y1)
737            return y1, y2
738
739        p0 = torch.randn(3, 4)
740        p1 = torch.randn(3, 4)
741        self.common(fn, (p0, p1))
742
743    def test_pow_cos(self):
744        # https://github.com/pytorch/pytorch/issues/98149
745        def fn(x):
746            t = x.pow(5)
747            return torch.cos(t)
748
749        x = torch.tensor([4], dtype=torch.uint8)
750        self.common(fn, (x,))
751
752    def test_reduce_with_masked(self):
753        # https://github.com/pytorch/pytorch/issues/96484
754        def fn(a, b):
755            a = torch.nn.functional.pad(a, (0, -1))
756            c = a + b
757            return c.min(0).values
758
759        a = torch.randn([2])
760        b = torch.randn([2])
761        self.common(fn, (a, b))
762
763    def test_scalar_sign_with_min(self):
764        # https://github.com/pytorch/pytorch/issues/101340
765        def fn(a):
766            t1 = torch.tanh(a)
767            t2 = torch.sign(t1)
768            return torch.min(t1, t2)
769
770        a = torch.randn(1, 3)
771        self.common(fn, (a,))
772
773    def test_index_propagation_issue_102065(self):
774        def fn(x):
775            x = torch.arange(x.numel())
776            return (x.unsqueeze(0) - x.unsqueeze(1)) ** 2
777
778        self.common(
779            fn,
780            (torch.randn(8),),
781        )
782
783    def test_ModularIndexing_range_issue_103133(self):
784        def fn(q, k):
785            einsum = torch.einsum("bcxd,bcyd->bcxy", (q, k))
786            constant_pad_nd = torch.ops.aten.constant_pad_nd.default(
787                einsum, [0, 0, 0, 1], 0.0
788            )
789            view = torch.ops.aten.view.default(constant_pad_nd, [12, 1, 512, 513])
790            y = view.new_zeros((12, 2, 256, 513))
791            y[:, :-1, :, 256:] = view[:, :, :256, :257]
792            return y
793
794        self.common(
795            fn,
796            (
797                torch.empty_strided((12, 1, 512, 64), (64, 196608, 768, 1)),
798                torch.empty_strided((12, 1, 512, 64), (64, 196608, 768, 1)),
799            ),
800        )
801
802    @patch("torch.cuda.is_available", lambda: False)
803    def test_max_reduction_lowp_fp(self):
804        def fn(x):
805            return torch.ops.aten.max(x, 1, keepdim=True)[0].float()
806
807        for dtype in _lowp_fp_dtypes:
808            self.common(
809                fn,
810                (torch.randn(1, 32, 4, 4).to(dtype),),
811            )
812
813    @patch("torch.cuda.is_available", lambda: False)
814    def test_vec_transpose_lowp_fp(self):
815        for dtype in _lowp_fp_dtypes:
816
817            def fn(x):
818                return x.to(memory_format=torch.channels_last).to(dtype)
819
820            self.common(
821                fn,
822                (torch.randn(2, 3, 4, 4),),
823            )
824
825    def test_load_inf_bf16(self):
826        def fn1(x):
827            return torch.where(x > 0, x, math.inf)
828
829        def fn2(x):
830            return torch.where(x > 0, x, -math.inf)
831
832        for fn in [fn1, fn2]:
833            self.common(
834                fn,
835                (torch.randn(1, 3, 16, 16),),
836            )
837
838    @patch("torch.cuda.is_available", lambda: False)
839    def test_fp32_load_with_to_lowp_fp(self):
840        # From llama model.
841        class Model(torch.nn.Module):
842            def __init__(self):
843                super().__init__()
844                self.cache_k = torch.zeros(8, 4, 2, 2)
845
846            def forward(self, x, xk):
847                bsz, seqlen, _ = x.shape
848                self.cache_k = self.cache_k.to(x)
849                self.cache_k[:bsz, 1 : 1 + seqlen] = xk
850                return self.cache_k
851
852        for dtype in _lowp_fp_dtypes:
853            ref_model = Model().eval()
854            opt_model = torch.compile()(Model().eval())
855            x = torch.randn(4, 2, 2).to(dtype)
856            xk = torch.randn(4, 2, 2, 2).to(dtype)
857            self.assertEqual(opt_model(x, xk), ref_model(x, xk))
858
859    @requires_vectorization
860    @patch("torch.cuda.is_available", lambda: False)
861    def test_sigmoid_with_reduction(self):
862        def fn(x):
863            x = torch.ops.aten.sigmoid.default(x)
864            return torch.ops.aten.mean.dim(x, [-1, -2], True)
865
866        x = torch.randn((1, 8, 8, 8))
867        with config.patch({"cpp.simdlen": None}):
868            torch._dynamo.reset()
869            metrics.reset()
870            self.common(fn, (x,))
871
872    def test_slice_scatter_default_end_value(self):
873        # From HF AllenaiLongformerBase.
874        def fn(query, key, window_overlap):
875            batch_size, seq_len, num_heads, head_dim = query.size()
876            assert (
877                seq_len % (window_overlap * 2) == 0
878            ), f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}"
879
880            chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1
881            diagonal_chunked_attention_scores = key
882            diagonal_attention_scores = diagonal_chunked_attention_scores.new_zeros(
883                (
884                    batch_size * num_heads,
885                    chunks_count + 1,
886                    window_overlap,
887                    window_overlap * 2 + 1,
888                )
889            )
890            diagonal_attention_scores[
891                :, :3, :, window_overlap:
892            ] = diagonal_chunked_attention_scores[
893                :, :, :window_overlap, : window_overlap + 1
894            ]
895            return diagonal_attention_scores
896
897        self.common(
898            fn,
899            (
900                torch.randn(1, 1024, 12, 64),
901                torch.randn(12, 3, 512, 513),
902                256,
903            ),
904        )
905
906    @requires_vectorization
907    @patch("torch.cuda.is_available", lambda: False)
908    def test_to_uint8_rounding_method(self):
909        def fn(x):
910            return x.to(torch.uint8)
911
912        numerical_testsuit = [4.4, 4.5, 4.6, 5.5]
913        for numerical_number in numerical_testsuit:
914            x = torch.ones(17) * numerical_number
915            with config.patch({"cpp.simdlen": None}):
916                torch._dynamo.reset()
917                metrics.reset()
918                self.common(fn, (x,))
919                check_metrics_vec_kernel_count(1)
920
921    @requires_vectorization
922    def _test_decomposed_dequant_relu_quant_helper(self, dtype):
923        def fn(
924            x, scale, zero_point, use_dequant, use_quant, quant_min, quant_max, dtype
925        ):
926            # For quantized_decomposed.dequantize_per_tensor
927            # Refer to torch/ao/quantization/fx/_decomposed.py
928            if use_dequant:
929                x = (x.to(torch.float32) - zero_point) * scale
930
931            x = torch.relu(x)
932
933            # For quantized_decomposed.quantize_per_tensor
934            # Refer to torch/ao/quantization/fx/_decomposed.py
935            if use_quant:
936                inv_scale = 1.0 / scale
937                x = torch.clamp(
938                    torch.round(x * inv_scale) + zero_point, quant_min, quant_max
939                ).to(dtype)
940            return x
941
942        assert dtype in [torch.uint8, torch.int8]
943        quant_min = 0 if dtype == torch.uint8 else -128
944        quant_max = 255 if dtype == torch.uint8 else 127
945
946        use_dequant_list = [False, True]
947        use_quant_list = [False, True]
948        for use_dequant, use_quant in itertools.product(
949            use_dequant_list, use_quant_list
950        ):
951            x = torch.clamp(
952                torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100,
953                quant_min,
954                quant_max,
955            )
956            if use_dequant:
957                x = x.to(dtype)
958            zero_point = 100
959            scale = 0.01
960            with config.patch({"cpp.simdlen": None}):
961                torch._dynamo.reset()
962                metrics.reset()
963                self.common(
964                    fn,
965                    (
966                        x,
967                        scale,
968                        zero_point,
969                        use_dequant,
970                        use_quant,
971                        quant_min,
972                        quant_max,
973                        dtype,
974                    ),
975                )
976                check_metrics_vec_kernel_count(1)
977
978    @requires_vectorization
979    def test_decomposed_dequant_relu_quant_uint8(self):
980        self._test_decomposed_dequant_relu_quant_helper(torch.uint8)
981
982    @requires_vectorization
983    def test_decomposed_dequant_relu_quant_int8(self):
984        self._test_decomposed_dequant_relu_quant_helper(torch.int8)
985
986    def _test_dequant_quant_lowering_helper(self, dtype):
987        def fn(
988            x, scale, zero_point, use_dequant, use_quant, quant_min, quant_max, dtype
989        ):
990            if use_dequant:
991                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
992                    x, scale, zero_point, quant_min, quant_max, dtype
993                )
994
995            x = torch.relu(x)
996
997            if use_quant:
998                x = torch.ops.quantized_decomposed.quantize_per_tensor(
999                    x, scale, zero_point, quant_min, quant_max, dtype
1000                )
1001            return x
1002
1003        use_dequant_list = [False, True]
1004        use_quant_list = [False, True]
1005        use_tensor_overload_list = [False, True]
1006
1007        assert dtype in [torch.uint8, torch.int8]
1008        quant_min = 0 if dtype == torch.uint8 else -128
1009        quant_max = 255 if dtype == torch.uint8 else 127
1010
1011        for use_dequant, use_quant, use_tensor_overload in itertools.product(
1012            use_dequant_list, use_quant_list, use_tensor_overload_list
1013        ):
1014            x = torch.clamp(
1015                torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100,
1016                quant_min,
1017                quant_max,
1018            )
1019            if use_dequant:
1020                x = x.to(dtype)
1021            zero_point = 100
1022            scale = 0.01
1023            if use_tensor_overload:
1024                zero_point = torch.tensor(zero_point, dtype=torch.int64)
1025                scale = torch.tensor(scale)
1026            with config.patch({"cpp.simdlen": None}):
1027                torch._dynamo.reset()
1028                metrics.reset()
1029                self.common(
1030                    fn,
1031                    (
1032                        x,
1033                        scale,
1034                        zero_point,
1035                        use_dequant,
1036                        use_quant,
1037                        quant_min,
1038                        quant_max,
1039                        dtype,
1040                    ),
1041                )
1042                check_metrics_vec_kernel_count(1)
1043
1044    @requires_vectorization
1045    def test_dequant_quant_lowering_uint8(self):
1046        self._test_dequant_quant_lowering_helper(torch.uint8)
1047
1048    @requires_vectorization
1049    def test_dequant_quant_lowering_int8(self):
1050        self._test_dequant_quant_lowering_helper(torch.int8)
1051
1052    def _test_dequant_maxpool2d_lowering_helper(self, dtype):
1053        def fn(x, scale, zero_point, quant_min, quant_max, dtype):
1054            x = torch.ops.quantized_decomposed.dequantize_per_tensor(
1055                x, scale, zero_point, quant_min, quant_max, dtype
1056            )
1057            max_pool2d_with_indices_default = (
1058                torch.ops.aten.max_pool2d_with_indices.default(
1059                    x, [2, 2], [2, 2], [1, 1]
1060                )[0]
1061            )
1062            return max_pool2d_with_indices_default
1063
1064        assert dtype in [torch.uint8, torch.int8]
1065        quant_min = 0 if dtype == torch.uint8 else -128
1066        quant_max = 255 if dtype == torch.uint8 else 127
1067
1068        use_tensor_overload_list = [False, True]
1069        for use_tensor_overload in use_tensor_overload_list:
1070            x = (
1071                torch.clamp(
1072                    torch.randn((3, 16, 8, 8), dtype=torch.float32) * 100,
1073                    quant_min,
1074                    quant_max,
1075                )
1076                .to(dtype)
1077                .contiguous(memory_format=torch.channels_last)
1078            )
1079            zero_point = 100
1080            scale = 0.01
1081            if use_tensor_overload:
1082                zero_point = torch.tensor(zero_point, dtype=torch.int64)
1083                scale = torch.tensor(scale)
1084            with config.patch({"cpp.simdlen": None}):
1085                torch._dynamo.reset()
1086                metrics.reset()
1087                self.common(fn, (x, scale, zero_point, quant_min, quant_max, dtype))
1088                check_metrics_vec_kernel_count(1)
1089
1090    @requires_vectorization
1091    def test_dequant_maxpool2d_lowering_uint8(self):
1092        self._test_dequant_maxpool2d_lowering_helper(torch.uint8)
1093
1094    @requires_vectorization
1095    def test_dequant_maxpool2d_lowering_int8(self):
1096        self._test_dequant_maxpool2d_lowering_helper(torch.int8)
1097
1098    def _test_tile2d_load_decomposed_dequant_add_relu_quant_helper(self, dtype):
1099        def fn(
1100            x,
1101            scale,
1102            zero_point,
1103            x2,
1104            scale2,
1105            zero_point2,
1106            output_scale,
1107            output_zero_point,
1108            use_dequant,
1109            use_dequant2,
1110            use_quant,
1111            quant_min,
1112            quant_max,
1113            dtype,
1114        ):
1115            if use_dequant:
1116                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
1117                    x, scale, zero_point, quant_min, quant_max, dtype
1118                )
1119            if use_dequant2:
1120                x2 = torch.ops.quantized_decomposed.dequantize_per_tensor(
1121                    x2, scale2, zero_point2, quant_min, quant_max, dtype
1122                )
1123            temp = x + x2
1124            y = torch.relu(temp)
1125
1126            if use_quant:
1127                y = torch.ops.quantized_decomposed.quantize_per_tensor(
1128                    y, output_scale, output_zero_point, quant_min, quant_max, dtype
1129                )
1130            return y.contiguous()
1131
1132        assert dtype in [torch.uint8, torch.int8]
1133        quant_min = 0 if dtype == torch.uint8 else -128
1134        quant_max = 255 if dtype == torch.uint8 else 127
1135
1136        use_dequant_list = [False, True]
1137        use_dequant_list2 = [False, True]
1138        use_quant_list = [False, True]
1139
1140        for use_dequant, use_dequant2, use_quant in itertools.product(
1141            use_dequant_list, use_dequant_list2, use_quant_list
1142        ):
1143            x = torch.clamp(
1144                torch.randn((1, 1024, 14, 14), dtype=torch.float32) * 100,
1145                quant_min,
1146                quant_max,
1147            ).contiguous(memory_format=torch.channels_last)
1148            x2 = torch.clamp(
1149                torch.randn((1, 1024, 14, 14), dtype=torch.float32) * 100,
1150                quant_min,
1151                quant_max,
1152            ).contiguous(memory_format=torch.channels_last)
1153            if use_dequant:
1154                x = x.to(dtype).contiguous(memory_format=torch.channels_last)
1155            if use_dequant2:
1156                x2 = x2.to(dtype).contiguous(memory_format=torch.channels_last)
1157            zero_point = 1
1158            scale = 0.01
1159            zero_point2 = 2
1160            scale2 = 0.02
1161            output_zero_point = 3
1162            output_scale = 0.03
1163            with config.patch({"cpp.simdlen": None}):
1164                torch._dynamo.reset()
1165                metrics.reset()
1166                self.common(
1167                    fn,
1168                    (
1169                        x,
1170                        scale,
1171                        zero_point,
1172                        x2,
1173                        scale2,
1174                        zero_point2,
1175                        output_scale,
1176                        output_zero_point,
1177                        use_dequant,
1178                        use_dequant2,
1179                        use_quant,
1180                        quant_min,
1181                        quant_max,
1182                        dtype,
1183                    ),
1184                )
1185                check_metrics_vec_kernel_count(2)
1186
1187    @requires_vectorization
1188    def test_tile2d_load_decomposed_dequant_add_relu_quant_uint8(self):
1189        self._test_tile2d_load_decomposed_dequant_add_relu_quant_helper(torch.uint8)
1190
1191    @requires_vectorization
1192    def test_tile2d_load_decomposed_dequant_add_relu_quant_int8(self):
1193        self._test_tile2d_load_decomposed_dequant_add_relu_quant_helper(torch.int8)
1194
1195    @requires_vectorization
1196    def _test_per_tensor_fake_quant_helper(self, dtype):
1197        def fn(input, scales, zero_points, quant_min, quant_max, dtype):
1198            input = torch.ops.quantized_decomposed.quantize_per_tensor(
1199                input, scales, zero_points, quant_min, quant_max, dtype
1200            )
1201            input = torch.ops.quantized_decomposed.dequantize_per_tensor(
1202                input, scales, zero_points, quant_min, quant_max, dtype
1203            )
1204            return input
1205
1206        use_tensor_overload_list = [False, True]
1207        for use_tensor_overload in use_tensor_overload_list:
1208            assert dtype in [torch.uint8, torch.int8]
1209            quant_min = 0 if dtype == torch.uint8 else -128
1210            quant_max = 255 if dtype == torch.uint8 else 127
1211            x = torch.clamp(
1212                torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100,
1213                quant_min,
1214                quant_max,
1215            )
1216            zero_point = 100
1217            scale = 0.01
1218            if use_tensor_overload:
1219                zero_point = torch.tensor(zero_point, dtype=torch.int64)
1220                scale = torch.tensor(scale)
1221            with config.patch({"cpp.simdlen": None}):
1222                torch._dynamo.reset()
1223                metrics.reset()
1224                self.common(fn, (x, scale, zero_point, quant_min, quant_max, dtype))
1225                assert metrics.generated_cpp_vec_kernel_count == 1
1226
1227    @requires_vectorization
1228    def test_per_tensor_fake_quant_uint8(self):
1229        self._test_per_tensor_fake_quant_helper(torch.uint8)
1230
1231    @requires_vectorization
1232    def test_per_tensor_fake_quant_int8(self):
1233        self._test_per_tensor_fake_quant_helper(torch.int8)
1234
1235    def _test_per_channel_fake_quant_helper(self, dtype, input_dtype=torch.float32):
1236        def fn(input, scales, zero_points, axis, quant_min, quant_max, dtype):
1237            input = torch.ops.quantized_decomposed.quantize_per_channel(
1238                input, scales, zero_points, axis, quant_min, quant_max, dtype
1239            )
1240            input = torch.ops.quantized_decomposed.dequantize_per_channel(
1241                input, scales, zero_points, axis, quant_min, quant_max, dtype
1242            )
1243            return input
1244
1245        assert dtype in [torch.uint8, torch.int8]
1246        quant_min = 0 if dtype == torch.uint8 else -128
1247        quant_max = 255 if dtype == torch.uint8 else 127
1248        x = torch.clamp(
1249            torch.randn((1, 3, 224, 224), dtype=torch.float32) * 100,
1250            quant_min,
1251            quant_max,
1252        )
1253        if input_dtype != torch.float32:
1254            x = x.to(dtype=input_dtype)
1255        scales = torch.ones((3,))
1256        zero_points = torch.zeros((3,))
1257        axis = 1
1258        with config.patch({"cpp.simdlen": None}):
1259            torch._dynamo.reset()
1260            metrics.reset()
1261            self.common(fn, (x, scales, zero_points, axis, quant_min, quant_max, dtype))
1262            check_metrics_vec_kernel_count(1)
1263
1264    @requires_vectorization
1265    def test_per_channel_fake_quant_uint8(self):
1266        self._test_per_channel_fake_quant_helper(torch.uint8)
1267
1268    @requires_vectorization
1269    def test_per_channel_fake_quant_module_uint8(self):
1270        class Mod(torch.nn.Module):
1271            def __init__(self):
1272                super().__init__()
1273                self.scales = torch.ones((3,)).to(torch.float64)
1274                self.zero_points = torch.zeros((3,)).to(torch.int64)
1275                self.axis = 1
1276                self.quant_min = 0
1277                self.quant_max = 255
1278                self.dtype = torch.uint8
1279
1280            def forward(self, input):
1281                input = torch.ops.quantized_decomposed.quantize_per_channel(
1282                    input,
1283                    self.scales,
1284                    self.zero_points,
1285                    self.axis,
1286                    self.quant_min,
1287                    self.quant_max,
1288                    self.dtype,
1289                )
1290                input = torch.ops.quantized_decomposed.dequantize_per_channel(
1291                    input,
1292                    self.scales,
1293                    self.zero_points,
1294                    self.axis,
1295                    self.quant_min,
1296                    self.quant_max,
1297                    self.dtype,
1298                )
1299                return input
1300
1301        m = Mod().eval()
1302        x = torch.clamp(
1303            torch.randn((1, 3, 224, 224), dtype=torch.float32) * 100,
1304            0,
1305            255,
1306        )
1307        with config.patch({"cpp.simdlen": None}):
1308            torch._dynamo.reset()
1309            metrics.reset()
1310            self.common(m, (x,))
1311            assert metrics.generated_cpp_vec_kernel_count == 1
1312
1313    @requires_vectorization
1314    def test_per_channel_fake_quant_int8(self):
1315        self._test_per_channel_fake_quant_helper(torch.int8)
1316
1317    @requires_vectorization
1318    def test_per_channel_fake_quant_uint8_bf16_input(self):
1319        self._test_per_channel_fake_quant_helper(
1320            torch.uint8, input_dtype=torch.bfloat16
1321        )
1322
1323    @requires_vectorization
1324    def test_per_channel_fake_quant_int8_bf16_input(self):
1325        self._test_per_channel_fake_quant_helper(torch.int8, input_dtype=torch.bfloat16)
1326
1327    def _test_non_contiguous_load_buf_quant_helper(self, dtype):
1328        def fn(
1329            x1,
1330            x2,
1331            groups,
1332            quant_min,
1333            quant_max,
1334            dtype,
1335        ):
1336            x = torch.cat((x1, x2), dim=1)
1337            batchsize, num_channels, height, width = x.size()
1338            channels_per_group = num_channels // groups
1339            x = torch.ops.quantized_decomposed.dequantize_per_tensor(
1340                x, 1.0, 0, quant_min, quant_max, dtype
1341            )
1342            x = x.view(batchsize, groups, channels_per_group, height, width)
1343            x = torch.ops.quantized_decomposed.quantize_per_tensor(
1344                x, 1.0, 0, quant_min, quant_max, dtype
1345            )
1346            x = torch.ops.quantized_decomposed.dequantize_per_tensor(
1347                x, 1.0, 0, quant_min, quant_max, dtype
1348            )
1349            x = torch.transpose(x, 1, 2).contiguous()
1350            x = x.view(batchsize, num_channels, height, width)
1351            return x
1352
1353        assert dtype in [torch.uint8, torch.int8]
1354        quant_min = 0 if dtype == torch.uint8 else -128
1355        quant_max = 255 if dtype == torch.uint8 else 127
1356
1357        x = torch.randint(0, 8, (1, 116, 28, 28), dtype=dtype).contiguous(
1358            memory_format=torch.channels_last
1359        )
1360        x2 = torch.randint(0, 8, (1, 116, 28, 28), dtype=dtype).contiguous(
1361            memory_format=torch.channels_last
1362        )
1363
1364        with config.patch({"cpp.simdlen": None}):
1365            torch._dynamo.reset()
1366            metrics.reset()
1367            self.common(
1368                fn,
1369                (
1370                    x,
1371                    x2,
1372                    2,
1373                    quant_min,
1374                    quant_max,
1375                    dtype,
1376                ),
1377            )
1378            check_metrics_vec_kernel_count(2)
1379
1380    @requires_vectorization
1381    def test_non_contiguous_load_buf_quant_uint8(self):
1382        self._test_non_contiguous_load_buf_quant_helper(torch.uint8)
1383
1384    @requires_vectorization
1385    def test_non_contiguous_load_buf_quant_int8(self):
1386        self._test_non_contiguous_load_buf_quant_helper(torch.int8)
1387
1388    def _test_tile2d_store_channel_shuffle_cl_quant_output_helper(self, dtype):
1389        def channel_shuffle(
1390            x, groups, output_scale, output_zero_point, quant_min, quant_max, dtype
1391        ):
1392            batchsize, num_channels, height, width = x.size()
1393            channels_per_group = num_channels // groups
1394            x = x.view(batchsize, groups, channels_per_group, height, width)
1395            x = torch.transpose(x, 1, 2).contiguous()
1396            x = x.view(batchsize, -1, height, width)
1397            x = torch.ops.quantized_decomposed.quantize_per_tensor(
1398                x, output_scale, output_zero_point, quant_min, quant_max, dtype
1399            )
1400            return x.contiguous(memory_format=torch.channels_last)
1401
1402        assert dtype in [torch.uint8, torch.int8]
1403        quant_min = 0 if dtype == torch.uint8 else -128
1404        quant_max = 255 if dtype == torch.uint8 else 127
1405
1406        with config.patch({"cpp.simdlen": None}):
1407            torch._dynamo.reset()
1408            metrics.reset()
1409            x = torch.randn(64, 58, 28, 28)
1410            output_zero_point = 3
1411            output_scale = 0.03
1412            self.common(
1413                channel_shuffle,
1414                (x, 2, output_scale, output_zero_point, quant_min, quant_max, dtype),
1415            )
1416            check_metrics_vec_kernel_count(2)
1417
1418    @requires_vectorization
1419    def test_tile2d_store_channel_shuffle_cl_quant_output_uint8(self):
1420        self._test_tile2d_store_channel_shuffle_cl_quant_output_helper(torch.uint8)
1421
1422    @requires_vectorization
1423    def test_tile2d_store_channel_shuffle_cl_quant_output_int8(self):
1424        self._test_tile2d_store_channel_shuffle_cl_quant_output_helper(torch.int8)
1425
1426    def _test_dequant_relu_quant_dequant_relu_quant_lowering_helper(self, dtype):
1427        def fn(
1428            x,
1429            scale,
1430            zero_point,
1431            scale2,
1432            zero_point2,
1433            scale3,
1434            zero_point3,
1435            quant_min,
1436            quant_max,
1437            dtype,
1438        ):
1439            x = torch.ops.quantized_decomposed.dequantize_per_tensor(
1440                x, scale, zero_point, quant_min, quant_max, dtype
1441            )
1442            x = torch.relu(x)
1443            x = torch.ops.quantized_decomposed.quantize_per_tensor(
1444                x, scale2, zero_point2, quant_min, quant_max, dtype
1445            )
1446            x = torch.ops.quantized_decomposed.dequantize_per_tensor(
1447                x, scale2, zero_point2, quant_min, quant_max, dtype
1448            )
1449            x = torch.relu(x)
1450            x = torch.ops.quantized_decomposed.quantize_per_tensor(
1451                x, scale3, zero_point3, quant_min, quant_max, dtype
1452            )
1453            return x
1454
1455        assert dtype in [torch.uint8, torch.int8]
1456        quant_min = 0 if dtype == torch.uint8 else -128
1457        quant_max = 255 if dtype == torch.uint8 else 127
1458
1459        for use_tensor_overload in [True, False]:
1460            x = torch.clamp(
1461                torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100,
1462                quant_min,
1463                quant_max,
1464            ).to(dtype)
1465            zero_point_list = [100, 101, 102]
1466            scale_list = [0.01, 0.02, 0.03]
1467            if use_tensor_overload:
1468                for i in range(len(zero_point_list)):
1469                    zero_point_list[i] = torch.tensor(
1470                        zero_point_list[i], dtype=torch.int64
1471                    )
1472                    scale_list[i] = torch.tensor(scale_list[i])
1473            zero_point, zero_point2, zero_point3 = zero_point_list
1474            scale, scale2, scale3 = scale_list
1475            with config.patch({"cpp.simdlen": None}):
1476                torch._dynamo.reset()
1477                metrics.reset()
1478                self.common(
1479                    fn,
1480                    (
1481                        x,
1482                        scale,
1483                        zero_point,
1484                        scale2,
1485                        zero_point2,
1486                        scale3,
1487                        zero_point3,
1488                        quant_min,
1489                        quant_max,
1490                        dtype,
1491                    ),
1492                    rtol=1e-2,
1493                    atol=1e-2,
1494                )
1495                check_metrics_vec_kernel_count(1)
1496
1497    @requires_vectorization
1498    def test_dequant_relu_quant_dequant_relu_quant_lowering_uint8(self):
1499        self._test_dequant_relu_quant_dequant_relu_quant_lowering_helper(torch.uint8)
1500
1501    @requires_vectorization
1502    def test_dequant_relu_quant_dequant_relu_quant_lowering_int8(self):
1503        self._test_dequant_relu_quant_dequant_relu_quant_lowering_helper(torch.int8)
1504
1505    def test_inplace_add_alpha(self):
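        # aten.add_ with a non-default alpha mutates x in place; the compiled graph must
        # reproduce that mutation so x3 matches the eagerly-updated x2.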
1506        def fn(x, y):
1507            aten.add_.Tensor(x, y, alpha=0.55)
1508            return (x,)
1509
1510        x1 = torch.zeros(10)
1511        x2 = torch.zeros(10)
1512        x3 = torch.zeros(10)
1513        y = torch.randn(10)
1514        fn_fx = make_fx(fn)(x1, y)
1515        fn_compiled = compile_fx_inner(fn_fx, [x1, y])
1516        fn(x2, y)
1517        fn_compiled([x3, y])
1518        assert same(x2, x3)
1519
1520    def test_int_div(self):
1521        def fn(x, y):
1522            s3 = x.size(1)
1523            a = torch.zeros((1 + s3) // 2)
1524            a += y
1525            return a, s3
1526
1527        p0 = torch.randint(5, (1, 8))
1528        p1 = torch.randn(1)
1529        self.common(fn, (p0, p1))
1530
1531    def test_no_op_squeeze(self):
1532        @torch._dynamo.optimize("inductor")
1533        def forward(arg0_1):
1534            return torch.ops.aten.squeeze.dim(arg0_1, 1)
1535
1536        x = torch.randn((10, 20))
1537        self.common(forward, (x,))
1538
1539    def test_parallel_num_threads(self):
1540        @torch._dynamo.optimize("inductor")
1541        def fn(x1, x2):
1542            return x1 + x2
1543
1544        x1 = torch.randn((10, 20))
1545        x2 = torch.randn((10, 20))
1546        with set_num_threads(1):
1547            assert same(x1 + x2, fn(x1, x2))
1548        with set_num_threads(4):
1549            assert same(x1 + x2, fn(x1, x2))
1550
1551    @patch("torch.cuda.is_available", lambda: False)
1552    def test_timed_cpu_only(self):
1553        timed(lambda: torch.randn(10), ())
1554
1555    def test_complex_memory_overlap(self):
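        # complex_memory_overlap should report False for dense tensors as well as for
        # transposed/split/unsqueezed/permuted and index_select results that do not
        # self-overlap in memory.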
1556        dense = torch.zeros(64, 32)
1557        self.assertFalse(complex_memory_overlap(dense))
1558        self.assertFalse(complex_memory_overlap(dense.t()))
1559
1560        strided = dense.split(4, dim=1)
1561        self.assertFalse(complex_memory_overlap(strided[0]))
1562        self.assertFalse(complex_memory_overlap(strided[0].t()))
1563
1564        unsqueezed = dense.unsqueeze(1)
1565        self.assertFalse(complex_memory_overlap(unsqueezed))
1566        self.assertFalse(complex_memory_overlap(unsqueezed.permute(1, 2, 0)))
1567
1568        gathered = dense.index_select(0, torch.IntTensor([1, 0, 1]))
1569        self.assertFalse(complex_memory_overlap(gathered))
1570        self.assertFalse(complex_memory_overlap(gathered.t()))
1571
1572    @requires_vectorization
1573    def test_vec_dynamic_shapes(self):
1574        def fn(x):
1575            return torch.softmax(x, -1)
1576
1577        value = torch.randn((2, 10))
1578        with config.patch({"cpp.simdlen": None}):
1579            torch._dynamo.reset()
1580            metrics.reset()
1581            self.common(fn, (value,))
1582
1583    @unittest.skipIf(
1584        platform.machine() != "x86_64" or not codecache.valid_vec_isa_list(),
1585        "Does not support vectorization or not x86_64 machine",
1586    )
1587    @patch("torch.cuda.is_available", lambda: False)
1588    def test_auto_simd(self):
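        # cpp.simdlen semantics exercised below: None auto-picks the widest valid ISA,
        # 0/1/unsupported widths disable vectorization, and 512/256 explicitly select
        # AVX512/AVX2 when they are available.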
1589        vec_avx512 = codecache.supported_vec_isa_list[0]
1590        vec_avx2 = codecache.supported_vec_isa_list[1]
1591        self.assertTrue(vec_avx512.bit_width() == 512)
1592        self.assertTrue(vec_avx2.bit_width() == 256)
1593        self.assertTrue(vec_avx512.nelements() == 16)
1594        self.assertTrue(vec_avx2.nelements() == 8)
1595        self.assertTrue(vec_avx512.nelements(torch.bfloat16) == 32)
1596        self.assertTrue(vec_avx2.nelements(torch.bfloat16) == 16)
1597
1598        with config.patch({"cpp.simdlen": None}):
1599            isa = codecache.pick_vec_isa()
1600            if vec_avx512 in codecache.valid_vec_isa_list():
1601                self.assertTrue(isa == vec_avx512)
1602            else:
1603                self.assertTrue(isa == vec_avx2)
1604
1605        with config.patch({"cpp.simdlen": 0}):
1606            isa = codecache.pick_vec_isa()
1607            self.assertFalse(isa)
1608
1609        with config.patch({"cpp.simdlen": 1}):
1610            isa = codecache.pick_vec_isa()
1611            self.assertFalse(isa)
1612
1613        with config.patch({"cpp.simdlen": 257}):
1614            isa = codecache.pick_vec_isa()
1615            self.assertFalse(isa)
1616
1617        with config.patch({"cpp.simdlen": 513}):
1618            isa_list = codecache.valid_vec_isa_list()
1619            if vec_avx512 in isa_list:
1620                self.assertFalse(codecache.pick_vec_isa())
1621
1622        with config.patch({"cpp.simdlen": 512}):
1623            isa_list = codecache.valid_vec_isa_list()
1624            if vec_avx512 in isa_list:
1625                isa = codecache.pick_vec_isa()
1626                self.assertTrue(isa == vec_avx512)
1627
1628        with config.patch({"cpp.simdlen": 256}):
1629            isa_list = codecache.valid_vec_isa_list()
1630            if vec_avx2 in isa_list:
1631                isa = codecache.pick_vec_isa()
1632                self.assertTrue(isa == vec_avx2)
1633
1634    @requires_vectorization
1635    @patch("torch.cuda.is_available", lambda: False)
1636    def test_masked_fill_softmax(self):
1637        def fn(value, mask):
1638            mask = mask.to(torch.bool)
1639            x = torch.masked_fill(value, mask, -33.0)
1640            return torch.softmax(x, -1)
1641
1642        for dtype in vec_dtypes:
1643            value = torch.randn((2, 17), dtype=dtype)
1644            mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8)
1645            with config.patch({"cpp.simdlen": None}):
1646                for cpp_wrapper_flag in [True, False]:
1647                    with config.patch({"cpp_wrapper": cpp_wrapper_flag}):
1648                        torch._dynamo.reset()
1649                        metrics.reset()
1650                        self.common(fn, (value, mask))
1651                        assert metrics.generated_cpp_vec_kernel_count >= 1
1652
1653    def test_channels_last_view_as_complex(self):
1654        # https://github.com/pytorch/pytorch/issues/122448#issuecomment-2046169554
1655
1656        def reduce_example(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
1657            """Applies the rotary embedding to the query and key tensors."""
1658            x_out = torch.view_as_complex(torch.stack([x.float(), y.float()], dim=-1))
1659            return x_out
1660
1661        args = [torch.randn(1, 1, 1, 128), torch.randn(1, 1, 1, 128)]
1662        expected = reduce_example(*args)
1663        actual = torch.compile(reduce_example, fullgraph=True)(*args)
1664        self.assertEqual(expected, actual)
1665
1666    def test_load_same_bool_tensor_twice(self):
1667        @torch._dynamo.optimize("inductor")
1668        def fn(a, b):
1669            x = torch.masked_fill(a, b, -33.0)
1670            y = torch.masked_fill(a, b, -33.0)
1671            return x, y
1672
1673        value = torch.randn((2, 17))
1674        mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8).to(torch.bool)
1675        fn(value, mask)
1676
1677    def test_cpu_vec_cosim(self):
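        # Every scalar op implemented in CppOverrides should either have a vectorized
        # counterpart in CppVecOverrides or be listed in the known-unvectorized `diff`
        # list below; the assertion at the end enforces that invariant.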
1678        cpp_vec_op_list = []
1679        cpp_op_list = []
1680
1681        for k, v in CppVecOverrides.__dict__.items():
1682            if isinstance(v, staticmethod):
1683                cpp_vec_op_list.append(k)
1684        for k, v in CppOverrides.__dict__.items():
1685            if isinstance(v, staticmethod):
1686                cpp_op_list.append(k)
1687
1688        diff = [
1689            "airy_ai",
1690            "bessel_j0",
1691            "bessel_j1",
1692            "bessel_y0",
1693            "bessel_y1",
1694            "modified_bessel_i0",
1695            "modified_bessel_i1",
1696            "modified_bessel_k0",
1697            "modified_bessel_k1",
1698            "scaled_modified_bessel_k0",
1699            "scaled_modified_bessel_k1",
1700            "spherical_bessel_j0",
1701            "i1",
1702            "i1e",
1703            "ndtr",
1704            "ndtri",
1705            "log_ndtr",
1706            "erfcx",
1707            "gammainc",
1708            "gammaincc",
1709            "igamma",
1710            "igammac",
1711            "polygamma",
1712            "zeta",
1713            "shifted_chebyshev_polynomial_u",
1714            "chebyshev_polynomial_u",
1715            "chebyshev_polynomial_t",
1716            "shifted_chebyshev_polynomial_w",
1717            "chebyshev_polynomial_w",
1718            "shifted_chebyshev_polynomial_t",
1719            "chebyshev_polynomial_v",
1720            "shifted_chebyshev_polynomial_v",
1721            "hermite_polynomial_he",
1722            "laguerre_polynomial_l",
1723            "hermite_polynomial_h",
1724            "legendre_polynomial_p",
1725            "constant",
1726            "index_expr",
1727            "signbit",
1728            "isinf",
1729            "frexp",
1730            "mod",
1731            "masked",
1732            "randn",
1733            "isnan",
1734            "rand",
1735            "randint64",
1736            "logical_and",
1737            "logical_not",
1738            "logical_or",
1739            "logical_xor",
1740            "bitwise_and",
1741            "bitwise_left_shift",
1742            "bitwise_not",
1743            "bitwise_right_shift",
1744            "bitwise_or",
1745            "bitwise_xor",
1746            "to_dtype_bitcast",
1747        ]
1748        union = {*cpp_vec_op_list, *diff}
1749        self.assertTrue(
1750            set(cpp_op_list).issubset(union), f"unexpected: {set(cpp_op_list) - union}"
1751        )
1752
1753    def test_atomic_add_lowp_fp(self):
1754        def fn(test_args):
1755            res = torch.gather(**test_args)
1756            return res
1757
1758        for dtype in _lowp_fp_dtypes:
1759            input_tensor_for_ref = torch.tensor(
1760                [[3.0, -5.0]], dtype=dtype, requires_grad=True
1761            )
1762            input_tensor_for_opt = torch.tensor(
1763                [[3.0, -5.0]], dtype=dtype, requires_grad=True
1764            )
1765
1766            test_args_for_ref = {
1767                "input": input_tensor_for_ref,
1768                "dim": 1,
1769                "index": torch.tensor([[1]]),
1770            }
1771            test_args_for_opt = {
1772                "input": input_tensor_for_opt,
1773                "dim": 1,
1774                "index": torch.tensor([[1]]),
1775            }
1776
1777            opt_fn = torch.compile(fn)
1778
1779            ref_fwd = fn(test_args_for_ref)
1780            res_fwd = opt_fn(test_args_for_opt)
1781            self.assertEqual(res_fwd, ref_fwd)
1782
1783            torch.manual_seed(1)
1784            bwd_tensor_for_ref = torch.randn(ref_fwd.shape, dtype=dtype)
1785            torch.manual_seed(1)
1786            bwd_tensor_for_opt = torch.randn(res_fwd.shape, dtype=dtype)
1787            self.assertEqual(bwd_tensor_for_ref, bwd_tensor_for_opt)
1788
1789            ref_fwd.backward(bwd_tensor_for_ref)
1790            res_fwd.backward(bwd_tensor_for_opt)
1791
1792            ref_grad = test_args_for_ref["input"].grad
1793            res_grad = test_args_for_opt["input"].grad
1794            self.assertEqual(ref_grad, res_grad)
1795
1796    def test_meta_device(self):
1797        @torch.compile(fullgraph=True)
1798        def fn():
1799            x = torch.ops.aten.empty.memory_format(
1800                [1024, 128, 128],
1801                dtype=torch.float16,
1802                device="meta",
1803                pin_memory=False,
1804            )
1805            return x.sin() + 1
1806
1807        self.assertEqual(fn().shape, [1024, 128, 128])
1808
1809    def test_decomposed_fake_quant_per_channel(self):
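        # Compares aten fake_quantize_per_channel_affine with the decomposed op, eager vs.
        # compiled. The element set to 257 lies outside [quant_min, quant_max], so its
        # gradient is expected to be zero.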
1810        def fq(input, scales, zero_points, axis, quant_min, quant_max):
1811            res = torch.fake_quantize_per_channel_affine(
1812                input, scales, zero_points, axis, quant_min, quant_max
1813            )
1814            return res
1815
1816        def qdq(input, scales, zero_points, axis, quant_min, quant_max):
1817            res = torch.ops.quantized_decomposed.fake_quant_per_channel(
1818                input, scales, zero_points, axis, quant_min, quant_max
1819            )
1820            return res
1821
1822        def run_eager_aten_fake_quant(
1823            input, scales, zero_points, axis, quant_min, quant_max
1824        ):
1825            input.grad = None
1826            res = fq(input, scales, zero_points, axis, quant_min, quant_max)
1827            res.sum().backward()
1828            return res, input.grad
1829
1830        def run_eager_decomposed_fake_quant(
1831            input, scales, zero_points, axis, quant_min, quant_max
1832        ):
1833            input.grad = None
1834            res = qdq(input, scales, zero_points, axis, quant_min, quant_max)
1835            res.sum().backward()
1836            return res, input.grad
1837
1838        def run_compile_decomposed_fake_quant(
1839            input, scales, zero_points, axis, quant_min, quant_max
1840        ):
1841            input.grad = None
1842            compiled_qdq = torch.compile(qdq)
1843            res = compiled_qdq(input, scales, zero_points, axis, quant_min, quant_max)
1844            res.sum().backward()
1845            return res, input.grad
1846
1847        input = torch.randn(2, 3, 224, 224)
1848        input[1, 2, 3, 4] = 257
1849        input.requires_grad_()
1850        scales = torch.ones((3,))
1851        zero_points = torch.zeros((3,))
1852        axis = 1
1853        quant_min = -128
1854        quant_max = 127
1855
1856        aten_input = copy.deepcopy(input)
1857        compiler_input = copy.deepcopy(input)
1858
1859        res_aten_eager, input_grad_aten_eager = run_eager_aten_fake_quant(
1860            aten_input, scales, zero_points, axis, quant_min, quant_max
1861        )
1862        res_decomp_eager, input_grad_decomp_eager = run_eager_decomposed_fake_quant(
1863            input, scales, zero_points, axis, quant_min, quant_max
1864        )
1865        res, input_grad = run_compile_decomposed_fake_quant(
1866            compiler_input, scales, zero_points, axis, quant_min, quant_max
1867        )
1868
1869        self.assertEqual(res_aten_eager, res)
1870        self.assertEqual(res_decomp_eager, res)
1871        self.assertEqual(input_grad_aten_eager, input_grad)
1872        self.assertEqual(input_grad_decomp_eager, input_grad)
1873        self.assertEqual(input_grad[1, 2, 3, 4], torch.tensor(0.0))
1874        # One vec kernel each for the forward and backward pass
1875        check_metrics_vec_kernel_count(2)
1876
1877    @requires_vectorization
1878    def test_ops_masked_with_bool_input(self):
1879        x = torch.zeros(129, dtype=torch.bool)
1880        size = [2, 3]
1881        res_aten_eager = torch.constant_pad_nd(x, size)
1882        cfn = torch.compile(torch.constant_pad_nd)
1883        res = cfn(x, size)
1884        self.assertEqual(res_aten_eager, res)
1885        check_metrics_vec_kernel_count(1)
1886
1887    def test_bitwise_right_shift(self):
1888        x = torch.randint(-1, 0, (1, 1, 1), device="cpu", dtype=torch.int64)
1889        bit_num = 31
1890        res_aten_eager = torch.bitwise_right_shift(x, bit_num)
1891        cfn = torch.compile(torch.bitwise_right_shift)
1892        res = cfn(x, bit_num)
1893        self.assertEqual(res_aten_eager, res)
1894
1895    @patch("torch.cuda.is_available", lambda: False)
1896    def test_scatter_using_atomic_add(self):
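        # With cpp.fallback_scatter_reduce_sum disabled, the scatter with reduce="add"
        # should be code-generated with atomic_add; with the fallback enabled it should
        # call aten.scatter_reduce_ instead. The single-thread cases below check the
        # interaction with cpp.dynamic_threads under the OpenMP backend.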
1897        def fn(a, dim, index, b):
1898            return aten.scatter(a, dim, index, b, reduce="add")
1899
1900        inps = (
1901            torch.randn(5, 29, 13),
1902            2,
1903            torch.tensor([[[3, 5, 7, 9]]]),
1904            torch.randn(1, 1, 10),
1905        )
1906
1907        def _internal_check(
1908            _fn,
1909            _inps,
1910            _target_code_check=None,
1911            _target_code_check_not=None,
1912        ):
1913            torch._dynamo.reset()
1914            metrics.reset()
1915            _fn_opt = torch.compile()(_fn)
1916            _, code = run_and_get_cpp_code(_fn_opt, *_inps)
1917            if _target_code_check:
1918                FileCheck().check(_target_code_check).run(code)
1919            if _target_code_check_not:
1920                FileCheck().check_not(_target_code_check_not).run(code)
1921
1922            self.assertEqual(
1923                _fn(*_inps),
1924                _fn_opt(*_inps),
1925            )
1926
1927        with config.patch({"cpp.fallback_scatter_reduce_sum": False}):
1928            _internal_check(fn, inps, "atomic_add")
1929
1930        with config.patch({"cpp.fallback_scatter_reduce_sum": True}):
1931            _internal_check(fn, inps, "aten.scatter_reduce_")
1932
1933        if "ATen parallel backend: OpenMP" in torch.__config__.parallel_info():
1934            # Fixes https://github.com/pytorch/pytorch/issues/118518,
1935            # where the thread number could not be changed with the native thread pool
1936            with set_num_threads(1):
1937                _internal_check(fn, inps, _target_code_check_not="aten.scatter_reduce_")
1938
1939            with config.patch({"cpp.dynamic_threads": True}), set_num_threads(1):
1940                _internal_check(fn, inps, "aten.scatter_reduce_")
1941
1942    @requires_vectorization
1943    @patch("torch.cuda.is_available", lambda: False)
1944    def test_new_vec_op_cpu_only(self):
1945        def fn(x):
1946            return torch.log1p(torch.expm1(torch.erf(x)))
1947
1948        for dtype in vec_dtypes:
1949            torch.manual_seed(0)
1950            x = torch.randn((2, 9), dtype=dtype)
1951            x[0, 0] = torch.nan
1952            x[1, -1] = torch.nan
1953
1954            tol = 1e-2 if dtype == torch.bfloat16 else 1e-4
1955
1956            with config.patch({"cpp.simdlen": None}):
1957                for cpp_wrapper_flag in [True, False]:
1958                    with config.patch({"cpp_wrapper": cpp_wrapper_flag}):
1959                        torch._dynamo.reset()
1960                        metrics.reset()
1961                        self.common(fn, (x,))
1962                        check_metrics_vec_kernel_count(1)
1963
1964    @requires_vectorization
1965    @patch("torch.cuda.is_available", lambda: False)
1966    def test_vec_cpu_only_for_all_available_isa(self):
1967        def fn(x):
1968            return torch.sin(torch.cos(torch.erf(x)))
1969
1970        x = torch.randn((2, 9))
1971        x[0, 0] = torch.nan
1972        x[1, -1] = torch.nan
1973
1974        bit_widths = [isa._bit_width for isa in codecache.valid_vec_isa_list()] + [None]
1975        for item in bit_widths:
1976            with config.patch({"cpp.simdlen": item}):
1977                torch._dynamo.reset()
1978                metrics.reset()
1979                self.common(fn, (x,))
1980                check_metrics_vec_kernel_count(1)
1981
1982    @slowTest
1983    @requires_vectorization
1984    @patch("torch.cuda.is_available", lambda: False)
1985    def test__adaptive_avg_pool2d(self):
1986        def wrap_fn(oh, ow):
1987            def fn(x):
1988                return torch._adaptive_avg_pool2d(x, (oh, ow))
1989
1990            return fn
1991
1992        bit_widths = [isa._bit_width for isa in codecache.valid_vec_isa_list()]
1993        ih = [16, 65]
1994        iw = ih
1995        oh = ih
1996        ow = ih
1997        for _ih, _iw, _oh, _ow, _simd_len, dtype in itertools.product(
1998            ih, iw, oh, ow, bit_widths, vec_dtypes
1999        ):
2000            x = torch.randn(2, 3, _ih, _iw, dtype=dtype).to(
2001                memory_format=torch.channels_last
2002            )
2003            _fn = wrap_fn(_oh, _ow)
2004            with config.patch({"cpp.simdlen": _simd_len}):
2005                torch._dynamo.reset()
2006                metrics.reset()
2007                self.common(_fn, (x,))
2008                check_metrics_vec_kernel_count(1)
2009
2010    @requires_vectorization
2011    @patch("torch.cuda.is_available", lambda: False)
2012    def test_vec_logical(self):
2013        def wrap_fn1(op: Callable):
2014            def fn(x: torch.Tensor):
2015                return torch.where(op(x), 1.0, 0.0)
2016
2017            return fn
2018
2019        def wrap_fn2(op: Callable):
2020            def fn(x: torch.Tensor, y: torch.Tensor):
2021                return torch.where(op(x, y), 1.0, 0.0)
2022
2023            return fn
2024
2025        for dtype in vec_dtypes:
2026            x = torch.randn(64, dtype=dtype)
2027            y = torch.randn(64, dtype=dtype)
2028            logical_fns = [
2029                torch.logical_and,
2030                torch.logical_not,
2031                torch.logical_or,
2032                torch.logical_xor,
2033            ]
2034            for logical_fn in logical_fns:
2035                torch._dynamo.reset()
2036                metrics.reset()
2037                if logical_fn == torch.logical_not:
2038                    _fn = wrap_fn1(logical_fn)
2039                    _args = (x,)
2040                else:
2041                    _fn = wrap_fn2(logical_fn)
2042                    _args = (x, y)
2043                self.common(_fn, _args)
2044                check_metrics_vec_kernel_count(1)
2045
2046    @requires_vectorization
2047    @patch("torch.cuda.is_available", lambda: False)
2048    def test_vec_compare_op_cpu_only(self):
2049        def fn(x):
2050            y1 = torch.eq(x, 1.0)
2051            x = torch.where(y1, x, -x)
2052            y2 = torch.ne(x, 0.0)
2053            x = torch.where(y2, x, -x)
2054            y3 = torch.lt(x, 5.0)
2055            x = torch.where(y3, x, x - 1.0)
2056            y4 = torch.gt(x, -2.0)
2057            x = torch.where(y4, x, x + 1.0)
2058            y5 = torch.le(x, 8.0)
2059            x = torch.where(y5, x, x - 1.0)
2060            y6 = torch.ge(x, -3.0)
2061            x = torch.where(y6, x, x + 1.0)
2062            y7 = x == 1.0
2063            x = torch.where(y7, x, -x)
2064            y8 = x != 0.0
2065            x = torch.where(y8, x, -x)
2066            y9 = x < 5.0
2067            x = torch.where(y9, x, x - 1.0)
2068            y10 = x > -2.0
2069            x = torch.where(y10, x, x + 1.0)
2070            y11 = x <= 8.0
2071            x = torch.where(y11, x, x - 1.0)
2072            y12 = x >= -3.0
2073            x = torch.where(y12, x, x + 1.0)
2074            return x
2075
2076        for dtype in vec_dtypes:
2077            x = torch.randn((2, 9), dtype=dtype)
2078
2079            with config.patch({"cpp.simdlen": None}):
2080                torch._dynamo.reset()
2081                metrics.reset()
2082                self.common(fn, (x,))
2083                check_metrics_vec_kernel_count(1)
2084                assert (
2085                    metrics.generated_kernel_count
2086                    - metrics.generated_cpp_vec_kernel_count
2087                ) == 0
2088
2089    def test_skip_cpp_codegen(self):
2090        with config.patch({"disable_cpp_codegen": True}):
2091            inps = torch.ones([20]), torch.rand([20])
2092
2093            def f(x, y):
2094                return x + y + torch.tensor(1)
2095
2096            f_opt = torch.compile()(f)
2097
2098            _, code = run_and_get_cpp_code(f_opt, inps[0], inps[1])
2099            FileCheck().check_not("void kernel").run(code)
2100
2101            self.assertEqual(
2102                f(*inps),
2103                f_opt(*inps),
2104            )
2105
2106            # constant needs to be propagated on fallback
2107            def f(x):
2108                return x[torch.tensor(1) :] * 2
2109
2110            f_opt = torch.compile()(f)
2111            _, code = run_and_get_cpp_code(f_opt, inps[0])
2112            FileCheck().check_not("void kernel").run(code)
2113            self.assertEqual(f_opt(inps[0]), f(inps[0]))
2114
2115            class Model(torch.nn.Module):
2116                def __init__(
2117                    self,
2118                ):
2119                    super().__init__()
2120
2121                def forward(self, v1: torch.Tensor):
2122                    vx = v1.min(dim=1).values
2123                    v2 = torch.randn_like(vx)
2124                    return v2
2125
2126            model = Model()
2127            x = torch.rand(10, 3, 0)
2128            model_f = torch.compile()(model)
2129
2130            self.assertEqual(model(x), model_f(x))
2131
2132    def test_redundant_to_node_elimination_lowp_fp(self):
2133        def fn(x, y):
2134            res = x + y
2135            res = torch.mean(res)
2136            return res
2137
2138        for dtype in _lowp_fp_dtypes:
2139            x = torch.randn((2, 9), dtype=dtype)
2140            y = torch.randn((2, 9), dtype=dtype)
2141
2142            for torch_compile_debug in [True, False]:
2143                with config.patch(
2144                    {"trace.enabled": torch_compile_debug, "cpp.simdlen": None}
2145                ):
2146                    torch._dynamo.reset()
2147                    metrics.reset()
2148                    self.common(fn, (x, y))
2149                    check_metrics_vec_kernel_count(1)
2150
2151    def test_do_not_insert_to_dtype_for_memory_copy_only_kernel(self):
2152        def fn(x):
2153            res = x.clone()
2154            return res
2155
2156        x = torch.randn((100, 100), dtype=torch.bfloat16)
2157
2158        torch._dynamo.reset()
2159        metrics.reset()
2160        self.common(fn, (x,))
2161        assert metrics.cpp_to_dtype_count == 0
2162        check_metrics_vec_kernel_count(1)
2163
2164    def test_insert_to_dtype_count(self):
2165        def fn(x):
2166            res = x.relu()
2167            return res
2168
2169        x = torch.randn((100, 100), dtype=torch.bfloat16)
2170
2171        torch._dynamo.reset()
2172        metrics.reset()
2173        self.common(fn, (x,))
2174        assert metrics.cpp_to_dtype_count == 2
2175        check_metrics_vec_kernel_count(1)
2176
2177    def test_memory_copy_with_fusion(self):
2178        def fn(x):
2179            res = x.relu()
2180            x.copy_(res)
2181            return (res,)
2182
2183        x = torch.randn((100, 100), dtype=torch.bfloat16)
2184
2185        torch._dynamo.reset()
2186        metrics.reset()
2187        self.common(fn, (x,))
2188        assert metrics.cpp_to_dtype_count == 2
2189        check_metrics_vec_kernel_count(1)
2190
2191    @requires_vectorization
2192    @patch("torch.cuda.is_available", lambda: False)
2193    def test_cpp_vec_constant_checker(self):
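        # Builds a small FX graph containing `constant` nodes and runs it through
        # CppVecKernelChecker: simd_vec stays True for the integer/float values and infs
        # probed below, and flips to False once a float constant falls outside the finite
        # float32 range.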
2194        _graph: torch.fx.Graph = torch.fx.Graph()
2195        a: torch.fx.Node = _graph.create_node("placeholder", "ops")
2196        iv: torch.fx.Node = _graph.create_node("placeholder", "iv")
2197        fv: torch.fx.Node = _graph.create_node("placeholder", "fv")
2198        b: torch.fx.Node = _graph.create_node(
2199            "call_method",
2200            "constant",
2201            args=(
2202                a,
2203                iv,
2204                torch.int64,
2205            ),
2206        )
2207        c: torch.fx.Node = _graph.create_node(
2208            "call_method",
2209            "constant",
2210            args=(
2211                a,
2212                fv,
2213                torch.double,
2214            ),
2215        )
2216        d: torch.fx.Node = _graph.create_node(
2217            "call_method",
2218            "ge",
2219            args=(
2220                a,
2221                b,
2222                b,
2223            ),
2224        )
2225        _graph.output((d, c))
2226
2227        def get_index():
2228            return ""
2229
2230        submodules = {"get_index": get_index}
2231
2232        graph_lowering = GraphLowering(
2233            torch.fx.GraphModule(submodules, _graph),
2234            shape_env=None,
2235        )
2236
2237        def set_opt_dtype(graph):
2238            for node in graph.nodes:
2239                if node.target == "constant":
2240                    if OptimizationContext.key in node.meta:
2241                        opt_ctx = node.meta[OptimizationContext.key]
2242                    else:
2243                        opt_ctx = OptimizationContext()
2244                    opt_ctx.dtype = node.args[-1]
2245                    node.meta[OptimizationContext.key] = opt_ctx
2246
2247        with patch.object(graph_lowering, "wrapper_code", ""), V.set_graph_handler(
2248            graph_lowering
2249        ):
2250            # Use the vector lane count of the picked ISA as the tiling factor
2251            tiling_factor = codecache.pick_vec_isa().nelements(dtype=torch.float)
2252            with CppVecKernelChecker(
2253                args=None, num_threads=1, tiling_factor=tiling_factor
2254            ) as vec_checker:
2255                i32_iinfo = np.iinfo(np.int32)
2256                f32_iinfo = np.finfo(np.float32)
2257                set_opt_dtype(_graph)
2258                InterpreterShim(_graph, submodules).run(
2259                    V.get_ops_handler(), i32_iinfo.max, f32_iinfo.max
2260                )
2261                self.assertTrue(vec_checker.simd_vec)
2262
2263                vec_checker.simd_vec = True
2264                set_opt_dtype(_graph)
2265                InterpreterShim(_graph, submodules).run(
2266                    V.get_ops_handler(), i32_iinfo.min, f32_iinfo.min
2267                )
2268                self.assertTrue(vec_checker.simd_vec)
2269
2270                vec_checker.simd_vec = True
2271                set_opt_dtype(_graph)
2272                InterpreterShim(_graph, submodules).run(
2273                    V.get_ops_handler(), i32_iinfo.min, np.inf
2274                )
2275                self.assertTrue(vec_checker.simd_vec)
2276
2277                vec_checker.simd_vec = True
2278                set_opt_dtype(_graph)
2279                InterpreterShim(_graph, submodules).run(
2280                    V.get_ops_handler(), i32_iinfo.min, -np.inf
2281                )
2282                self.assertTrue(vec_checker.simd_vec)
2283
2284                vec_checker.simd_vec = True
2285                set_opt_dtype(_graph)
2286                InterpreterShim(_graph, submodules).run(
2287                    V.get_ops_handler(), i32_iinfo.min - 1, f32_iinfo.min
2288                )
2289                self.assertTrue(vec_checker.simd_vec)
2290
2291                vec_checker.simd_vec = True
2292                set_opt_dtype(_graph)
2293                InterpreterShim(_graph, submodules).run(
2294                    V.get_ops_handler(), i32_iinfo.max + 1, f32_iinfo.max
2295                )
2296                self.assertTrue(vec_checker.simd_vec)
2297
2298                vec_checker.simd_vec = True
2299                set_opt_dtype(_graph)
2300                InterpreterShim(_graph, submodules).run(
2301                    V.get_ops_handler(), i32_iinfo.min, f32_iinfo.min * (1 + 1e-5)
2302                )
2303                self.assertFalse(vec_checker.simd_vec)
2304
2305                vec_checker.simd_vec = True
2306                set_opt_dtype(_graph)
2307                InterpreterShim(_graph, submodules).run(
2308                    V.get_ops_handler(), i32_iinfo.max, f32_iinfo.max * (1 + 1e-5)
2309                )
2310                self.assertFalse(vec_checker.simd_vec)
2311
2312    @requires_vectorization
2313    @patch("torch.cuda.is_available", lambda: False)
2314    def test_cpp_vec_index_expr_checker(self):
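        # Same setup as above but with an `index_expr` node: the checker should keep
        # vectorization when the index expression stays within int32 over the iteration
        # ranges and disable it when the expression can fall outside the int32 range.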
2315        _graph: torch.fx.Graph = torch.fx.Graph()
2316        a: torch.fx.Node = _graph.create_node("placeholder", "ops")
2317        b: torch.fx.Node = _graph.create_node("call_module", "get_index", args=())
2318        c: torch.fx.Node = _graph.create_node(
2319            "call_method",
2320            "index_expr",
2321            args=(
2322                a,
2323                b,
2324                torch.int64,
2325            ),
2326        )
2327        d: torch.fx.Node = _graph.create_node(
2328            "call_method",
2329            "ge",
2330            args=(
2331                a,
2332                c,
2333                c,
2334            ),
2335        )
2336        _graph.output(d)
2337
2338        def get_index():
2339            return ""
2340
2341        submodules = {"get_index": get_index}
2342        graph_lowering = GraphLowering(
2343            torch.fx.GraphModule(submodules, _graph),
2344            shape_env=None,
2345        )
2346        with patch.object(graph_lowering, "wrapper_code", ""), V.set_graph_handler(
2347            graph_lowering
2348        ):
2349            itervars = [sympy.Symbol("i"), sympy.Symbol("j"), sympy.Symbol("k")]
2350
2351            tiling_factor = codecache.pick_vec_isa().nelements(dtype=torch.float)
2352            # The innermost loop variable is used in the index_expr
2353            with CppVecKernelChecker(
2354                args=None, num_threads=1, tiling_factor=tiling_factor
2355            ) as vec_checker:
2356
2357                def get_index():
2358                    return -itervars[0] ** 2 + 2 * itervars[0] + itervars[1]
2359
2360                ranges = [0, 100, 200]
2361                vec_checker.itervars = itervars[:2]
2362                vec_checker.ranges = ranges[:2]
2363                submodules = {"get_index": get_index}
2364                InterpreterShim(_graph, submodules).run(V.get_ops_handler())
2365                self.assertTrue(vec_checker.simd_vec)
2366
2367            # The innermost loop variable is not used in the index_expr
2368            with CppVecKernelChecker(
2369                args=None, num_threads=1, tiling_factor=tiling_factor
2370            ) as vec_checker:
2371
2372                def get_index():
2373                    return -itervars[0] ** 2 + 2 * itervars[0] + itervars[1]
2374
2375                ranges = [0, 100, 200]
2376                vec_checker.itervars = itervars
2377                vec_checker.ranges = ranges
2378                submodules = {"get_index": get_index}
2379                InterpreterShim(_graph, submodules).run(V.get_ops_handler())
2380                self.assertTrue(vec_checker.simd_vec)
2381
2382            i32_iinfo = np.iinfo(np.int32)
2383            _max_value = i32_iinfo.max + 1
2384            ranges = [_max_value, _max_value, _max_value]
2385            # The innermost loop variable is not used in the index_expr, but the
2386            # index expression's max value exceeds the INT32 max
2387            with CppVecKernelChecker(
2388                args=None, num_threads=1, tiling_factor=tiling_factor
2389            ) as vec_checker:
2390
2391                def get_index():
2392                    return itervars[0]
2393
2394                submodules = {"get_index": get_index}
2395                vec_checker.itervars = itervars
2396                vec_checker.ranges = ranges
2397                InterpreterShim(_graph, submodules).run(V.get_ops_handler())
2398                self.assertFalse(vec_checker.simd_vec)
2399
2400            # The innermost loop variable is not used in the index_expr, but the
2401            # index expression's min value falls below the INT32 min
2402            with CppVecKernelChecker(
2403                args=None, num_threads=1, tiling_factor=tiling_factor
2404            ) as vec_checker:
2405
2406                def get_index():
2407                    return -itervars[0] - 2
2408
2409                submodules = {"get_index": get_index}
2410                vec_checker.itervars = itervars
2411                vec_checker.ranges = ranges
2412                InterpreterShim(_graph, submodules).run(V.get_ops_handler())
2413                self.assertFalse(vec_checker.simd_vec)
2414
2415    @requires_vectorization
2416    @patch("torch.cuda.is_available", lambda: False)
2417    def test_maxpool2d_cpu_only(self):
2418        for dtype in vec_dtypes:
2419            input = torch.randn(26, 32, 112, 112, dtype=dtype).to(
2420                memory_format=torch.channels_last
2421            )
2422            maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
2423
2424            def func(x):
2425                return maxpool(x)
2426
2427            with patch.object(config.cpp, "simdlen", None):
2428                torch._dynamo.reset()
2429                metrics.reset()
2430                self.common(func, (input,))
2431                check_metrics_vec_kernel_count(1)
2432
2433    @requires_vectorization
2434    @patch("torch.cuda.is_available", lambda: False)
2435    def test_maxpool2d_with_pre_loop_collapse_cpu_only(self):
2436        x1 = torch.randn(2, 3, 20, 20).to(memory_format=torch.channels_last)
2437        x2 = torch.randn(2, 3, 20, 20).to(memory_format=torch.channels_last)
2438        maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
2439
2440        def func(x1, x2):
2441            y = x1 + x2
2442            return maxpool(y)
2443
2444        with patch.object(config.cpp, "simdlen", None):
2445            torch._dynamo.reset()
2446            metrics.reset()
2447            self.common(func, (x1, x2))
2448            check_metrics_vec_kernel_count(2)
2449
2450    def test_randint_symint_input(self):
2451        # https://github.com/pytorch/pytorch/issues/122405
2452        @torch.compile(fullgraph=True)
2453        def get_traj_idx(lengths: torch.Tensor, num_slices: int) -> torch.Tensor:
2454            return torch.randint(lengths.shape[0], (num_slices,), device=lengths.device)
2455
2456        lengths = torch.zeros(10, dtype=torch.long)
2457        get_traj_idx(lengths, num_slices=4)
2458        lengths = torch.zeros(11, dtype=torch.long)
2459        get_traj_idx(lengths, num_slices=4)
2460
2461    @requires_vectorization
2462    @patch("torch.cuda.is_available", lambda: False)
2463    def test_sign_cpu_only(self):
2464        def fn(x):
2465            return torch.sign(x)
2466
2467        for dtype in vec_dtypes:
2468            x = torch.randn((2, 9), dtype=dtype)
2469            x[0, 0] = torch.nan
2470            x[1, -1] = torch.nan
2471
2472            with config.patch({"cpp.simdlen": None}):
2473                torch._dynamo.reset()
2474                metrics.reset()
2475                self.common(fn, (x,))
2476                check_metrics_vec_kernel_count(1)
2477
2478    @requires_vectorization
2479    @patch("torch.cuda.is_available", lambda: False)
2480    def test_reduction_cpu_only(self):
2481        def fn(x):
2482            return torch.argmax(x, -1)
2483
2484        for dtype in vec_dtypes:
2485            x = torch.randn((10, 10), dtype=dtype)
2486
2487            with config.patch({"cpp.simdlen": None}):
2488                torch._dynamo.reset()
2489                metrics.reset()
2490                self.common(fn, (x,))
2491                assert metrics.generated_cpp_vec_kernel_count == 0
2492
2493    def test_outer_loop_fusion(self):
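        # The amax reduction and the subtraction share their outer loops, so a single
        # outer-loop-fused group containing two inner kernels is expected.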
2494        def fn(x):
2495            max = torch.amax(x, dim=-1, keepdim=True)
2496            return x - max
2497
2498        x = torch.randn(4, 12, 1023, 1022)
2499
2500        with config.patch({"cpp.simdlen": None}):
2501            torch._dynamo.reset()
2502            metrics.reset()
2503            self.common(fn, (x,))
2504            assert len(metrics.cpp_outer_loop_fused_inner_counts) == 1
2505            assert metrics.cpp_outer_loop_fused_inner_counts[0] == 2
2506
2507    def test_argmin(self):
2508        def fn(x):
2509            return torch.argmin(x, -1)
2510
2511        for dtype in vec_dtypes:
2512            x = torch.randn((10, 10), dtype=dtype)
2513            torch._dynamo.reset()
2514            metrics.reset()
2515            self.common(fn, (x,))
2516            assert metrics.generated_cpp_vec_kernel_count == 0
2517
2518    def test_argmax_argmin_with_nan_value(self):
2519        def fn(x):
2520            return torch.argmax(x)
2521
2522        def fn2(x):
2523            return torch.argmin(x)
2524
2525        inputs = [
2526            torch.Tensor([-755832.1250, 100]),
2527            torch.Tensor([-755832.1250, 100, 200]),
2528            torch.Tensor([100, -755832.1250]),
2529            torch.Tensor([100, 200, -755832.1250]),
2530        ]
2531
2532        for x in inputs:
2533            x = x.repeat(16, 16)
2534            x = torch.log1p(x)
2535
2536            # Test argmax
2537            torch._dynamo.reset()
2538            metrics.reset()
2539            self.common(fn, (x,))
2540            assert metrics.generated_cpp_vec_kernel_count == 0
2541
2542            # Test argmin
2543            torch._dynamo.reset()
2544            metrics.reset()
2545            self.common(fn2, (x,))
2546            assert metrics.generated_cpp_vec_kernel_count == 0
2547
2548    # Currently, we enable AVX2 and AVX512 for vectorization. If the platform does not
2549    # support them, vectorization will not work and this test case is skipped. To support
2550    # ARM or other platforms, we just need to add the ISA info to the supported vec ISA
2551    # list and include the proper ATen vectorization header file.
2552    @requires_vectorization
2553    @patch("torch.cuda.is_available", lambda: False)
2554    def test_vec_kernel_cpu_only(self):
2555        def fn(x1, x2):
2556            # Currently, there are some limitations as follows.
2557            #   rsqrt:
2558            #     assert [both a fallback and a decomp for same kernel: aten.rsqrt.default]
2559            #   round:
2560            #     couldn't find symbolic meta function/decomposition
2561            #   fmod/logical_and/logical_or:
2562            #     the vec kernel does not support the needed dtype conversion
2563            x = torch.abs(x1)
2564            x = torch.sin(x)
2565            x = torch.neg(x)
2566            x = torch.square(x)
2567            x = torch.sigmoid(x)
2568            x = torch.relu(x)
2569            x = torch.cos(x)
2570            x = torch.exp(x)
2571            x = torch.sqrt(x)
2572            x = torch.add(x, x1)
2573            x = torch.sub(x, x2)
2574            x = torch.mul(x, x1)
2575            x = torch.div(x, x1)
2576            x = torch.pow(x, 10)
2577            x = torch.log(x)
2578            x = torch.floor(x)
2579            x = torch.ceil(x)
2580            x = torch.trunc(x)
2581            x = torch.lgamma(x)
2582            x = torch.fmod(x, x2)
2583            x = torch.sign(x)
2584            res = x + x2
2585            return res
2586
2587        for dtype in vec_dtypes:
2588            torch.manual_seed(0)
2589            x1 = torch.randn((5, 20), dtype=dtype)
2590            x2 = torch.randn((5, 20), dtype=dtype)
2591
2592            tol = 1e-2 if dtype == torch.bfloat16 else 1e-4
2593            with config.patch({"cpp.simdlen": 1}):
2594                torch._dynamo.reset()
2595                metrics.reset()
2596                self.common(fn, (x1, x2))
2597                assert metrics.generated_cpp_vec_kernel_count == 0
2598
2599            with config.patch({"cpp.simdlen": None}):
2600                torch._dynamo.reset()
2601                metrics.reset()
2602                self.common(fn, (x1, x2))
2603                check_metrics_vec_kernel_count(1)
2604
2605        with config.patch({"cpp.simdlen": None}):
2606            torch._dynamo.reset()
2607            metrics.reset()
2608            x1 = torch.randn(10, 20).permute(1, 0)
2609            x2 = torch.randn((20, 10))
2610            self.common(fn, (x1, x2))
2611            check_metrics_vec_kernel_count(2)
2612
2613            torch._dynamo.reset()
2614            metrics.reset()
2615            x1 = torch.randn((10, 7))
2616            x2 = torch.randn((10, 7))
2617            self.common(fn, (x1, x2))
2618            check_metrics_vec_kernel_count(1)
2619
2620    @unittest.skipIf(
2621        sys.platform != "linux", "cpp kernel profile only supports linux now"
2622    )
2623    @patch("torch.cuda.is_available", lambda: False)
2624    @config.patch({"cpp.enable_kernel_profile": True})
2625    @config.patch({"cpp.descriptive_names": "original_aten"})
2626    def test_cpp_kernel_profile(self):
2627        from torch.profiler import profile
2628
2629        @torch._dynamo.optimize("inductor", nopython=True)
2630        def fn(a, b):
2631            return a + b
2632
2633        a = torch.rand((100,))
2634        b = torch.rand((100,))
2635        with profile() as prof:
2636            fn(a, b)
2637
2638        kernel_profile_events = []
2639        for e in prof.profiler.function_events:
2640            if "cpp_fused_add_0" in e.name:
2641                kernel_profile_events.append(e.name)
2642        assert len(kernel_profile_events) > 0
2643
2644    @requires_vectorization
2645    def test_channel_shuffle_cl_output(self):
2646        """code and shape extracted from shufflenet_v2_x1_0"""
2647
2648        def channel_shuffle(x, groups):
2649            batchsize, num_channels, height, width = x.size()
2650            channels_per_group = num_channels // groups
2651            x = x.view(batchsize, groups, channels_per_group, height, width)
2652            x = torch.transpose(x, 1, 2).contiguous()
2653            x = x.view(batchsize, -1, height, width)
2654            return x.contiguous(memory_format=torch.channels_last)
2655
2656        for simdlen in (None, 256, 1):
2657            with config.patch({"cpp.simdlen": simdlen}):
2658                torch._dynamo.reset()
2659                metrics.reset()
2660                x = torch.randn(64, 58, 28, 28)
2661                self.common(channel_shuffle, (x, 2))
2662                if simdlen != 1:
2663                    check_metrics_vec_kernel_count(2)
2664
2665    @slowTest
2666    @requires_vectorization
2667    def test_transpose_with_norm(self):
2668        """a sub-module from TIMM gmlp_s16_224"""
2669
2670        class Model(torch.nn.Module):
2671            def __init__(self):
2672                super().__init__()
2673                self.linear = torch.nn.Linear(
2674                    in_features=256, out_features=1536, bias=True
2675                )
2676                self.act = torch.nn.GELU()
2677                self.norm = torch.nn.LayerNorm(768)
2678                self.proj = torch.nn.Linear(196, 196)
2679                self.fc = torch.nn.Linear(in_features=768, out_features=256, bias=True)
2680
2681            def forward(self, x):
2682                x = self.linear(x)
2683                x = self.act(x)
2684                u, v = x.chunk(2, dim=-1)
2685                v = self.norm(v)
2686                v = self.proj(v.transpose(-1, -2))
2687                y = u * v.transpose(-1, -2)
2688                return self.fc(y)
2689
2690        x = torch.randn(128, 196, 256)
2691        for simdlen in (None, 256, 1):
2692            with config.patch({"cpp.simdlen": simdlen}):
2693                for eval_mode in [True, False]:
2694                    torch._dynamo.reset()
2695                    metrics.reset()
2696                    m = Model().eval() if eval_mode else Model()
2697                    self.common(m, (x,))
2698                    if simdlen != 1:
2699                        check_metrics_vec_kernel_count(8)
2700
2701    @requires_vectorization
2702    def test_transpose_copy(self):
2703        def fn(a):
2704            return a.t().contiguous()
2705
2706        for simdlen in (None, 256, 1):
2707            with config.patch({"cpp.simdlen": simdlen}):
2708                for dtype in (torch.float, torch.bfloat16):
2709                    for shape in (
2710                        (7, 7),
2711                        (8, 8),
2712                        (9, 9),
2713                        (16, 16),
2714                        (17, 17),
2715                        (32, 32),
2716                        (33, 33),
2717                    ):
2718                        torch._dynamo.reset()
2719                        metrics.reset()
2720                        x = torch.randn(shape, dtype=dtype)
2721                        self.common(fn, (x,))
2722                        if simdlen != 1:
2723                            check_metrics_vec_kernel_count(2)
2724
2725    @torch._dynamo.config.patch(specialize_int=False)
2726    def test_slice_scatter_issue122291(self):
2727        @torch.compile(fullgraph=True)
2728        def fn(t, t_src, dim, start, end, step):
2729            return t.slice_scatter(t_src, dim, start, end, step)
2730
2731        shape = ((16, 16), (16, 2), 1, 4, 10, 1)
2732        input_tensor = torch.zeros(shape[0], requires_grad=False, device="cpu")
2733        src_tensor = torch.ones(shape[1], requires_grad=False, device="cpu")
2734        with self.assertRaisesRegex(
2735            torch._dynamo.exc.BackendCompilerFailed, r".*shape error in scatter op"
2736        ):
2737            fn(input_tensor, src_tensor, shape[2], shape[3], shape[4], shape[5])
2738
2739    def test_horizontal_fusion(self):
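        # cpp.max_horizontal_fusion_size caps how many independent kernels can be fused
        # side by side: with 0 or 1 the three index_select kernels stay separate, with 2
        # two of them fuse, and with 3 all three collapse into a single generated kernel.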
2740        def fn(a, b, c, idx):
2741            _a = torch.index_select(a, dim=0, index=idx)
2742            _b = torch.index_select(b, dim=0, index=idx)
2743            _c = torch.index_select(c, dim=0, index=idx)
2744            return _a, _b, _c
2745
2746        with config.patch({"cpp.max_horizontal_fusion_size": 0}):
2747            metrics.reset()
2748            torch._dynamo.reset()
2749            a = torch.randn(size=(4, 16), dtype=torch.bfloat16)
2750            b = torch.randn(size=(4, 16), dtype=torch.bfloat16)
2751            c = torch.randn(size=(4, 16), dtype=torch.bfloat16)
2752            idx = torch.zeros(size=[4], dtype=torch.int64)
2753            opt_fn = torch._dynamo.optimize("inductor")(fn)
2754            opt_fn(a, b, c, idx)
2755            self.assertEqual(metrics.generated_kernel_count, 3)
2756            self.assertTrue(same(fn(a, b, c, idx), opt_fn(a, b, c, idx)))
2757
2758        with config.patch({"cpp.max_horizontal_fusion_size": 1}):
2759            metrics.reset()
2760            torch._dynamo.reset()
2761            a = torch.randn(size=(4, 32), dtype=torch.bfloat16)
2762            b = torch.randn(size=(4, 32), dtype=torch.bfloat16)
2763            c = torch.randn(size=(4, 32), dtype=torch.bfloat16)
2764            idx = torch.zeros(size=[4], dtype=torch.int64)
2765            opt_fn = torch._dynamo.optimize("inductor")(fn)
2766            opt_fn(a, b, c, idx)
2767            self.assertEqual(metrics.generated_kernel_count, 3)
2768            self.assertTrue(same(fn(a, b, c, idx), opt_fn(a, b, c, idx)))
2769
2770        with config.patch({"cpp.max_horizontal_fusion_size": 2}):
2771            metrics.reset()
2772            torch._dynamo.reset()
2773            a = torch.randn(size=(4, 64), dtype=torch.bfloat16)
2774            b = torch.randn(size=(4, 64), dtype=torch.bfloat16)
2775            c = torch.randn(size=(4, 64), dtype=torch.bfloat16)
2776            idx = torch.zeros(size=[4], dtype=torch.int64)
2777            opt_fn = torch._dynamo.optimize("inductor")(fn)
2778            opt_fn(a, b, c, idx)
2780            self.assertEqual(metrics.generated_kernel_count, 2)
2781            self.assertTrue(same(fn(a, b, c, idx), opt_fn(a, b, c, idx)))
2782
2783        with config.patch({"cpp.max_horizontal_fusion_size": 3}):
2784            metrics.reset()
2785            torch._dynamo.reset()
2786            a = torch.randn(size=(4, 128), dtype=torch.bfloat16)
2787            b = torch.randn(size=(4, 128), dtype=torch.bfloat16)
2788            c = torch.randn(size=(4, 128), dtype=torch.bfloat16)
2789            idx = torch.zeros(size=[4], dtype=torch.int64)
2790            opt_fn = torch._dynamo.optimize("inductor")(fn)
2791            opt_fn(a, b, c, idx)
2792            self.assertEqual(metrics.generated_kernel_count, 1)
2793            self.assertTrue(same(fn(a, b, c, idx), opt_fn(a, b, c, idx)))
2794
2795    def test_lowp_fp_neg_abs(self):
2796        def fn(x):
2797            return x.neg().abs()
2798
2799        for dtype in _lowp_fp_dtypes:
2800            metrics.reset()
2801            x = torch.randn(100, 100).to(dtype)
2802            opt_fn = torch._dynamo.optimize("inductor")(fn)
2803            self.assertTrue(same(fn(x), opt_fn(x)))
2804            assert metrics.cpp_to_dtype_count == 0
2805            check_metrics_vec_kernel_count(1)
2806
2807    def test_transpose_non_contiguous(self):
2808        def fn(a):
2809            # From part of timm HaloAttn:
2810            # (https://github.com/rwightman/pytorch-image-models/blob/main/timm/layers/halo_attn.py#L97).
2811            # Fixes the accuracy issue reported in https://github.com/pytorch/pytorch/issues/94269.
2812            as_strided = torch.ops.aten.as_strided.default(
2813                a, [1, 384, 2, 20, 12], [153600, 1, 61440, 384, 7680]
2814            )
2815            as_strided_1 = torch.ops.aten.as_strided.default(
2816                as_strided,
2817                [1, 384, 2, 2, 12, 12],
2818                [153600, 1, 61440, 3072, 7680, 384],
2819            )
2820            clone_1 = torch.ops.aten.clone.default(
2821                as_strided_1, memory_format=torch.contiguous_format
2822            )
2823            _unsafe_view_1 = torch.ops.aten._unsafe_view.default(
2824                clone_1, [8, 48, 4, 144]
2825            )
2826            permute_2 = torch.ops.aten.permute.default(_unsafe_view_1, [0, 2, 3, 1])
2827            split_with_sizes = torch.ops.aten.split_with_sizes.default(
2828                permute_2, [16, 32], -1
2829            )
2830            getitem = split_with_sizes[0]
2831            getitem_1 = split_with_sizes[1]
2832            permute_3 = torch.ops.aten.permute.default(getitem, [0, 1, 3, 2])
2833            expand_1 = torch.ops.aten.expand.default(permute_3, [8, 4, 16, 144])
2834            clone_3 = torch.ops.aten.clone.default(
2835                expand_1, memory_format=torch.contiguous_format
2836            )
2837            return clone_3
2838
2839        metrics.reset()
2840        x = torch.randn(1, 384, 20, 20).to(memory_format=torch.channels_last)
2841        self.common(fn, (x,))
2842        check_metrics_vec_kernel_count(1)
2843
2844    def test_non_contiguous_index_with_constant_stride(self):
2845        def fn(x):
2846            x1 = x[:, :, :, ::2]
2847            x2 = x[:, :, :, 1::2]
2848            x = torch.stack((-x2, x1), dim=-1)
2849            return x.flatten(-2)
2850
2851        metrics.reset()
2852        x = torch.randn(1, 32, 16, 68)
2853        opt_fn = torch._dynamo.optimize("inductor")(fn)
2854        _, code = run_and_get_cpp_code(opt_fn, x)
2855        self.assertTrue(same(fn(x), opt_fn(x)))
2856        # the kernel is defined and then called, so "cpp_fused" should appear exactly twice
2857        FileCheck().check_count("cpp_fused", 2, exactly=True).run(code)
2858
2859    def test_invalid_index_of_empty_tensor(self):
2860        def fn(a):
2861            b = a[[0]]
2862            return b
2863
2864        a = torch.tensor([])
2865        with self.assertRaises(RuntimeError):
2866            torch.compile(fn)(a)
2867
2868    @torch.no_grad()
2869    @torch._inductor.config.patch(freezing=True)
2870    def test_issue122380(self):
2871        def func(x):
2872            t1 = torch.unbind(x)
2873            t2 = torch.stack(t1, dim=1)
2874            t3 = torch.tanh(t2)
2875            return t3
2876
2877        x = torch.randn(2, 3, 4)
2878        self.assertEqual(torch.compile(func)(x), func(x))
2879
2880    def test_ir_node_str(self):
2881        @torch.compile
2882        def fn(x: torch.Tensor) -> torch.Tensor:
2883            return x.sin(), torch.nn.Softmax(dim=1)(x.cos())
2884
2885        def run_node_alt(*args, **kwargs):
2886            rv = run_node(*args, **kwargs)
2887            strings.append(str(rv))
2888            return rv
2889
2890        strings = []
2891        run_node = GraphLowering.run_node
2892        with patch.object(GraphLowering, "run_node", run_node_alt):
2893            fn(torch.randn([8, 128]))
2894        self.assertGreater(len(strings), 3)
2895
2896    def test_vertical_sum_cpu_only(self):
2897        def fn1(a):
2898            return a.sum(dim=0)
2899
2900        def fn2(a):
2901            return a.sum(dim=1)
2902
2903        metrics.reset()
2904        x = torch.randn(100, 100)
2905        self.common(fn1, (x,))
2906        check_metrics_vec_kernel_count(1)
2907
2908        metrics.reset()
2909        x = torch.randn(100, 100, 100)
2910        self.common(fn2, (x,))
2911        check_metrics_vec_kernel_count(1)
2912
2913    def test_transpose_vertical_sum_cpu_only(self):
2914        def fn(a, b):
2915            c = a * b
2916            return c.sum(dim=1)
2917
2918        metrics.reset()
2919        x = torch.randn(100, 50, 50)
2920        y = torch.randn(100, 50, 50).transpose(1, 2)
2921        self.common(fn, (x, y))
2922        check_metrics_vec_kernel_count(2)
2923
2924    def test_transpose_mxn_16_16_bf16_fp16(self):
2925        def fn(a, b):
2926            c = a * b
2927            return c.sum(dim=1)
2928
2929        for dtype in [torch.bfloat16, torch.float16]:
2930            metrics.reset()
2931            x = torch.randn(100, 50, 50).to(dtype)
2932            y = torch.randn(100, 50, 50).to(dtype).transpose(1, 2)
2933            self.common(fn, (x, y))
2934            check_metrics_vec_kernel_count(2)
2935
2936    def test_transpose_mxn_32_32_bf16_fp16(self):
2937        def fn(a):
2938            return a.permute(0, 2, 1).contiguous()
2939
2940        for dtype in [torch.bfloat16, torch.float16]:
2941            metrics.reset()
2942            x = torch.randn(2, 9216, 9216).to(dtype)
2943            self.common(fn, (x,))
2944            check_metrics_vec_kernel_count(2)
2945
2946    def test_transpose_sum2d_cpu_only(self):
2947        def fn(a, b):
2948            c = a * b
2949            return c.sum()
2950
2951        metrics.reset()
2952        x = torch.randn(50, 50)
2953        y = torch.randn(50, 50).transpose(0, 1)
2954        self.common(fn, (x, y))
2955        check_metrics_vec_kernel_count(2)
2956
2957    def test_transpose_sum_outer(self):
2958        # https://github.com/pytorch/pytorch/issues/98573
2959        def fn(a):
2960            return a.transpose(2, 3).sum(dim=1).contiguous()
2961
2962        metrics.reset()
2963        x = torch.randn(10, 50, 50, 50)
2964        self.common(fn, (x,))
2965        check_metrics_vec_kernel_count(1)
2966
2967    def test_to_dtype_bool_float(self):
2968        # https://github.com/pytorch/pytorch/issues/100800
2969        def f(a):
2970            return torch.where(
2971                torch.ones_like(a).to(torch.bool),
2972                torch.zeros_like(a),
2973                torch.ones_like(a) * 2,
2974            )
2975
2976        self.common(f, (torch.ones(16),))
2977
2978    def test_to_dtype_float_bool(self):
2979        # https://github.com/pytorch/pytorch/issues/100466
2980        def f(a):
2981            a = a * torch.tensor(a >= 0, dtype=torch.float32)
2982            return a
2983
2984        x = torch.rand(16)
2985        self.common(f, (x,))
2986
2987    def test_constant_store(self):
2988        # https://github.com/pytorch/pytorch/issues/104515
2989        def f(a):
2990            a[0, [3, 3]] = -float("inf")
2991            return a
2992
2993        x = torch.rand(4, 5)
2994        self.common(f, (x,))
2995
2996    def test_to_channels_last_lowp_fp(self):
2997        def f(a):
2998            return a.to(memory_format=torch.channels_last)
2999
3000        for dtype in _lowp_fp_dtypes:
3001            x = torch.rand(2, 3, 14, 14).to(dtype)
3002            self.common(f, (x,))
3003
3004    def test_broadcast_mul_lowp_fp(self):
3005        def f(a, b):
3006            return a * b
3007
3008        for dtype in _lowp_fp_dtypes:
3009            a = torch.randn(2, 16, 16).to(dtype)
3010            b = torch.randn(2, 1, 1).to(dtype)
3011            self.common(f, (a, b))
3012
3013    def test_linear_buffer_reuse(self):
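        # The intermediate between the two linear layers is expected to be
        # reused in place, so the generated code should not reinterpret
        # buffers via as_strided.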
3014        class M(torch.nn.Module):
3015            def __init__(self):
3016                super().__init__()
3017                self.linear1 = torch.nn.Linear(16, 16)
3018                self.tanh = torch.nn.Tanh()
3019                self.linear2 = torch.nn.Linear(16, 16)
3020
3021            def forward(self, x):
3022                x = self.linear1(x)
3023                x = self.tanh(x)
3024                x = self.linear2(x)
3025                return x
3026
3027        mod = M().eval()
3028        v = torch.randn(1, 16)
3029
3030        with torch.no_grad():
3031
3032            def compile_fx_wrapper(model_, example_inputs_):
3033                return compile_fx(model_, example_inputs_)
3034
3035            def run(*ex, **kwargs):
3036                return mod(*ex, **kwargs)
3037
3038            run = torch._dynamo.optimize(compile_fx_wrapper)(run)
3039            _, code = run_and_get_cpp_code(run, v)
3040            self.assertFalse("= as_strided(" in code)
3041            self.assertEqual(run(*v), mod(*v))
3042
3043    def test_invalid_dropout_args(self):
3044        class MyModel(torch.nn.Module):
3045            def forward(self, x):
3046                x = x * 2
3047                x = torch.nn.functional.dropout(x, p=0.5)
3048                x = torch.relu(x)
3049                return x
3050
3051        example_inputs = torch.tensor([[1, 2, 3], [4, 5, 6]])
3052
3053        func = MyModel()
3054        jit_func = torch.compile(func)
3055        self.assertRaises(RuntimeError, lambda: func(example_inputs))
3056        self.assertRaises(RuntimeError, lambda: jit_func(example_inputs))
3057
3058    def test_nn_param_assign(self):
3059        # https://github.com/pytorch/pytorch/issues/99569
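        # Reassigning nn.Parameter objects inside forward() should compile
        # with fullgraph=True and match eager results.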
3060        class Model2(nn.Module):
3061            def __init__(self):
3062                super().__init__()
3063                self.conv = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=3)
3064                self.batchnorm = nn.BatchNorm2d(num_features=5)
3065                self.conv_weight = torch.randn(5, 3, 3, 3)
3066                self.conv_bias = torch.randn(5)
3067
3068            def forward(self, x):
3069                self.conv.weight = nn.Parameter(self.conv_weight)
3070                self.conv.bias = nn.Parameter(self.conv_bias, requires_grad=False)
3071                self.conv.eval()
3072                x = self.conv(x)
3073                x = self.batchnorm(x)
3074                x = F.relu(x)
3075                return x
3076
3077        input_tensor = torch.randn(1, 3, 10, 10)
3078        func = Model2().to("cpu")
3079
3080        with torch.no_grad():
3081            func.train(False)
3082            v1 = func(input_tensor)
3083            jit_func = torch.compile(func, fullgraph=True)
3084            v2 = jit_func(input_tensor)
3085            self.assertEqual(v1, v2)
3086
3087    def test_nn_param_assign_wrapped(self):
3088        class Model2(nn.Module):
3089            def __init__(self):
3090                super().__init__()
3091                self.conv = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=3)
3092                self.batchnorm = nn.BatchNorm2d(num_features=5)
3093                self.conv_weight = torch.randn(5, 3, 3, 3)
3094                self.conv_bias = torch.randn(5)
3095
3096            def forward(self, x):
3097                self.conv.weight = nn.Parameter(self.conv_weight)
3098                self.conv.bias = nn.Parameter(self.conv_bias, requires_grad=False)
3099                self.conv.eval()
3100                x = self.conv(x)
3101                x = self.batchnorm(x)
3102                x = F.relu(x)
3103                return x
3104
3105        input_tensor = torch.randn(1, 3, 10, 10)
3106        func = Model2().to("cpu")
3107
3108        @functools.wraps(func)
3109        def wrapper(*args, **kwargs):
3110            return func(*args, **kwargs)
3111
3112        with torch.no_grad():
3113            func.train(False)
3114            v1 = func(input_tensor)
3115            jit_func = torch.compile(wrapper, fullgraph=True)
3116            v2 = jit_func(input_tensor)
3117            self.assertEqual(v1, v2)
3118
3119    @config.patch(inplace_buffers=True)
3120    def test_in_out_buffer(self):
3121        def fn(x, y):
3122            z = torch.matmul(x, y.transpose(-1, -2)) / 8.0
3123            return z
3124
3125        inps = [torch.randn(1, 2, 8, 4), torch.randn(1, 2, 8, 4)]
3126        fn_opt = torch._dynamo.optimize("inductor")(fn)
3127        _, code = run_and_get_cpp_code(fn_opt, *inps)
3128        self.assertTrue("in_out_ptr" in code)
3129        self.assertEqual(fn_opt(*inps), fn(*inps))
3130
3131    def test_eliminate_meaningless_copy(self):
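        # The contiguous clone of the permuted tensor is a layout-only copy
        # that Inductor should be able to eliminate, leaving a single
        # generated kernel.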
3132        def fn(x1, x2):
3133            permute = torch.ops.aten.permute.default(x2, [0, 2, 1, 3])
3134            clone = torch.ops.aten.clone.default(
3135                permute, memory_format=torch.contiguous_format
3136            )
3137            view = torch.ops.aten.view.default(clone, [1024, -1, 32])
3138            bmm = torch.ops.aten.bmm.default(view, x1)
3139            permute = torch.ops.aten.permute.default(view, [0, 2, 1])
3140            return (bmm, permute)
3141
3142        metrics.reset()
3143        self.common(
3144            fn,
3145            [
3146                rand_strided(
3147                    (1024, 32, 128), (4096, 1, 32), device="cpu", dtype=torch.float32
3148                ),
3149                rand_strided(
3150                    (64, 128, 16, 32),
3151                    (65536, 512, 32, 1),
3152                    device="cpu",
3153                    dtype=torch.float32,
3154                ),
3155            ],
3156        )
3157        self.assertEqual(metrics.generated_kernel_count, 1)
3158
3159    def test_attention_size_mismatch(self):
3160        class Attention(torch.nn.Module):
3161            def __init__(self, hidden_size, num_heads):
3162                super().__init__()
3163                self.hidden_size = hidden_size
3164                self.num_heads = num_heads
3165                self.head_size = hidden_size // num_heads
3166                self.query = torch.nn.Linear(hidden_size, hidden_size)
3167                self.key = torch.nn.Linear(hidden_size, hidden_size)
3168                self.value = torch.nn.Linear(hidden_size, hidden_size)
3169                self.inv_scale = torch.nn.Parameter(
3170                    torch.Tensor([1 / self.head_size**0.5]), requires_grad=False
3171                )
3172
3173            def forward(self, x):
3174                query = self.query(x)
3175                key = self.key(x)
3176                value = self.value(x)
3177                (batch_size, seq_len, hidden_size) = query.size()
3178                query = query.view(
3179                    batch_size, seq_len, self.num_heads, self.head_size
3180                ).permute(0, 2, 1, 3)
3181                key = key.view(
3182                    batch_size, seq_len, self.num_heads, self.head_size
3183                ).permute(0, 2, 3, 1)
3184                value = value.view(
3185                    batch_size, seq_len, self.num_heads, self.head_size
3186                ).permute(0, 2, 1, 3)
3187                attention_weights = (
3188                    torch.matmul(query, key).mul(self.inv_scale).softmax(dim=-1)
3189                )
3190                output = torch.matmul(attention_weights, value)
3191                return output
3192
3193        torch.manual_seed(123)
3194        hidden_size = 16
3195        num_heads = 1
3196        seq_len = 4
3197        batch_size = 1
3198        x = torch.randn(batch_size, seq_len, hidden_size)
3199
3200        func = Attention(hidden_size, num_heads).to("cpu")
3201
3202        with torch.no_grad():
3203            res1 = func(x)
3204            jit_func = torch.compile(func)
3205            res2 = jit_func(x)
3206        self.assertEqual(res1, res2)
3207
3208    def test_scalar_mul_bfloat16(self):
3209        def f(x):
3210            return torch.ops.aten.mul.Tensor(x, 1.7015043497085571)
3211
3212        metrics.reset()
3213        x = torch.randn(4, 5, dtype=torch.bfloat16)
3214        self.common(f, (x,))
3215        check_metrics_vec_kernel_count(1)
3216
3217    def test_bf16_zeros(self):
3218        def fn():
3219            x = torch.zeros(1, 1, 32, dtype=torch.bfloat16)
3220            return x
3221
3222        self.common(fn, ())
3223
3224    def test_select_tiling_with_index_expr(self):
3225        def fn(x, y):
3226            x = torch.ops.aten.view.default(x, [8, 8, 8, 3136])
3227            x = torch.ops.aten.permute.default(x, [0, 1, 3, 2])
3228            y = torch.ops.aten.mul.Tensor(y, x)
3229            return torch.ops.aten.constant_pad_nd.default(y, [0, 0, 1, 0, 0, 0], 0.0)
3230
3231        x = torch.randn(8, 64, 56, 56)
3232        y = torch.randn(8, 8, 3136, 8)
3233        self.common(fn, (x, y))
3234
3235    @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
3236    @patch("torch.cuda.is_available", lambda: False)
3237    @config.patch(freezing=True)
3238    def test_linear_with_no_default_contiguous_input(self):
3239        dtypes = [
3240            torch.float32,
3241        ]
3242        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
3243            dtypes.append(torch.bfloat16)
3244        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
3245            dtypes.append(torch.float16)
3246        mod = torch.nn.Sequential(torch.nn.Linear(16, 16)).eval()
3247        temp = torch.randn(1, 16, 1, 1)
3248        v = torch.as_strided(temp, [1, 16], [0, 1], 0)
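        # A zero stride on the size-1 batch dim still reports as contiguous,
        # but it is not the default contiguous layout Linear usually sees.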
3249        self.assertTrue(v.is_contiguous())
3250        for dtype in dtypes:
3251            with torch.no_grad():
3252                self.common(
3253                    mod.to(dtype),
3254                    (v.to(dtype),),
3255                )
3256
3257    @patch("torch.cuda.is_available", lambda: False)
3258    @config.patch(freezing=True)
3259    def test_linear_with_reshape(self):
3260        class M(torch.nn.Module):
3261            def __init__(self):
3262                super().__init__()
3263                self.linear = torch.nn.Linear(16, 16, bias=False)
3264
3265            def forward(self, x):
3266                x = self.linear(x)
3267                return x.view(4, 4, 4)
3268
3269        mod = M().eval()
3270        v = torch.randn(4, 16)
3271        with torch.no_grad():
3272            torch._dynamo.reset()
3273            metrics.reset()
3274            self.common(
3275                mod,
3276                (v,),
3277            )
3278            assert metrics.generated_kernel_count == 0
3279
3280    @config.patch(implicit_fallbacks=True)
3281    def test_aten_normal_dtype(self):
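        # aten.normal with an explicit dtype should keep that dtype through
        # both the decomposition backend and Inductor, and default to float32
        # when dtype is None.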
3282        for dtype in [torch.float64, torch.float16, None]:
3283
3284            def fn():
3285                return torch.normal(2, 3, (10, 10), dtype=dtype, device="cpu")
3286
3287            self.assertEqual(
3288                torch.compile(fn, backend="aot_eager_decomp_partition")().dtype,
3289                dtype if dtype else torch.float32,
3290            )
3291            self.assertEqual(
3292                torch.compile(fn, backend="inductor")().dtype,
3293                dtype if dtype else torch.float32,
3294            )
3295
3296    def test_group_norm_vec(self):
3297        class M(torch.nn.Module):
3298            def __init__(self):
3299                super().__init__()
3300                self.group_norm = torch.nn.GroupNorm(32, 32)
3301
3302            def forward(self, x):
3303                return self.group_norm(x)
3304
3305        metrics.reset()
3306        mod = M().eval()
3307        x = torch.randn(2, 32, 32, 32)
3308        with torch.no_grad():
3309            self.common(mod, (x,))
3310            # 2 generated kernels (one for var_mean, the other for result)
3311            check_metrics_vec_kernel_count(2)
3312
3313    def test_int_div_vec(self):
3314        def fn(x, y, mode):
3315            return torch.div(x, y, rounding_mode=mode)
3316
3317        x = torch.randint(1, 100, (32, 32))
3318        y = torch.randint(1, 100, (32, 32))
3319        for mode in [None, "trunc", "floor"]:
3320            with torch.no_grad():
3321                metrics.reset()
3322                self.common(fn, (x, y, mode))
3323                check_metrics_vec_kernel_count(1)
3324
3325    def test_uint8_add(self):
3326        # https://github.com/pytorch/pytorch/issues/113016
3327        def fn(x, y):
3328            return torch.add(x, y).neg().to(torch.int32)
3329
3330        x = torch.randint(0, 255, (3, 3), dtype=torch.uint8)
3331        y = torch.randint(0, 255, (3, 3), dtype=torch.uint8)
3332        self.common(fn, (x, y))
3333
3334    def test_uint8_sub(self):
3335        # https://github.com/pytorch/pytorch/issues/113016
3336        def fn(x, y):
3337            return torch.sub(x, y).neg().to(torch.int32)
3338
3339        x = torch.randint(0, 255, (3, 3), dtype=torch.uint8)
3340        y = torch.randint(0, 255, (3, 3), dtype=torch.uint8)
3341        self.common(fn, (x, y))
3342
3343    def test_non_contiguous_reduction_store(self):
3344        # https://github.com/pytorch/pytorch/issues/113018
3345        class M(torch.nn.Module):
3346            def __init__(self):
3347                super().__init__()
3348                self.conv = torch.nn.Conv2d(39, 1, kernel_size=(1, 17), stride=(2, 2))
3349
3350            def forward(self, x):
3351                return self.conv(x.max(3).values)
3352
3353        m = M()
3354        x = torch.randn(1, 39, 1, 18, 17)
3355        self.common(m, (x,))
3356
3357    def test_embedding_vec(self):
3358        class M(torch.nn.Module):
3359            def __init__(self):
3360                super().__init__()
3361                self.emb = torch.nn.Embedding(64, 128)
3362
3363            def forward(self, idx, x):
3364                return self.emb(idx) + x
3365
3366        idx = torch.randint(0, 64, (4, 32))
3367        x = torch.randn(4, 32, 128)
3368        m = M().eval()
3369        with torch.no_grad():
3370            metrics.reset()
3371            self.common(m, (idx, x))
3372            check_metrics_vec_kernel_count(1)
3373
3374    def test_embedding_vec_bf16(self):
3375        class M(torch.nn.Module):
3376            def __init__(self):
3377                super().__init__()
3378                self.emb = torch.nn.Embedding(64, 128)
3379
3380            def forward(self, idx, x):
3381                return self.emb(idx)
3382
3383        idx = torch.randint(0, 64, (4, 32))
3384        x = torch.randn(4, 32, 128).to(torch.bfloat16)
3385        m = M().eval()
3386        with torch.no_grad():
3387            metrics.reset()
3388            self.common(m, (idx, x))
3389            check_metrics_vec_kernel_count(1)
3390
3391        # We do a direct load/store here; make sure no redundant type casts
3392        # are generated.
3393        m_opt = torch.compile(m)
3394        _, code = run_and_get_cpp_code(m_opt, idx, x)
3395        self.assertTrue("Vectorized" in code)
3396        self.assertTrue("cvt_lowp_fp_to_fp32" not in code)
3397        self.assertTrue("cvt_fp32_to_lowp_fp" not in code)
3398
3399    def test_concat_inner_vec(self):
3400        def fn(x, y):
3401            return F.relu(torch.cat([x, y], dim=1))
3402
3403        x = torch.randn(32, 35)
3404        y = torch.randn(32, 120)
3405        metrics.reset()
3406        self.common(fn, (x, y))
3407        check_metrics_vec_kernel_count(3)
3408
3409    def test_expr_vec_non_contiguous(self):
3410        def fn(x):
3411            # the pattern from sebotnet33ts_256
3412            y = torch.nn.functional.pad(x, (0, 31)).reshape(-1, 33, 63)
3413            y = y[:, :32, 31:].reshape(4, 32, 1, 32, 32).expand(-1, -1, 32, -1, -1)
3414            y = y.permute(0, 3, 1, 4, 2).clone(memory_format=torch.contiguous_format)
3415            y = y.view(4, 1024, 1024)
3416            return y.softmax(dim=-1)
3417
3418        x = torch.randn(128, 2048)
3419        opt_fn = torch.compile(fn)
3420        metrics.reset()
3421        _, code = run_and_get_cpp_code(opt_fn, x)
3422        self.assertTrue(same(fn(x), opt_fn(x)))
3423        # 4 kernels for max, exp, sum and div
3424        check_metrics_vec_kernel_count(4)
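        # The index expression should be vectorized directly; a load through
        # tmpbuf would indicate a scalar fallback path.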
3425        FileCheck().check_count(
3426            "Vectorized<int>::loadu(tmpbuf.data())", 0, exactly=True
3427        ).run(code)
3428
3429    def test_vec_contiguous_ModularIndexing(self):
3430        # https://github.com/pytorch/pytorch/issues/114488
3431        class M(torch.nn.Module):
3432            def __init__(self, dim):
3433                super().__init__()
3434                self.norm = torch.nn.LayerNorm(dim * 4)
3435
3436            def forward(self, x):
3437                # the pattern from swin_base_patch4_window7_224
3438                B, H, W, C = x.shape
3439                x = (
3440                    x.reshape(B, H // 2, 2, W // 2, 2, C)
3441                    .permute(0, 1, 3, 4, 2, 5)
3442                    .flatten(3)
3443                )
3444                x = self.norm(x)
3445                return x
3446
3447        x = torch.randn(1, 56, 56, 128)
3448        m = M(128)
3449        opt_m = torch.compile(m)
3450        with torch.no_grad():
3451            metrics.reset()
3452            _, code = run_and_get_cpp_code(opt_m, x)
3453            self.assertTrue(same(m(x), opt_m(x)))
3454            # Two kernels: one for the reduction, one for the pointwise ops
3455            check_metrics_vec_kernel_count(2)
3456            FileCheck().check_count(
3457                "Vectorized<float>::loadu(tmpbuf.data())", 0, exactly=True
3458            ).run(code)
3459
3460    @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float))
3461    @parametrize("shape", ("15,3,13", "4,2048,4096"))
3462    def test_fp8_cast(self, dtype: torch.dtype, shape: str):
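        # Round-trip casts through float8_e4m3fn and float8_e5m2 should match
        # eager results for each source dtype and shape.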
3463        def fp8_cast(x):
3464            y0 = x.to(dtype=torch.float8_e4m3fn).to(dtype)
3465            y1 = x.to(dtype=torch.float8_e5m2).to(dtype)
3466            return y0, y1
3467
3468        shape = [int(dim) for dim in shape.split(",")]
3469        x = torch.rand(*shape, device="cpu", dtype=dtype)
3470        self.common(fp8_cast, (x,))
3471
3472    def test_logical_op_store_to_lowp_data_dtype(self):
3473        # https://github.com/pytorch/pytorch/issues/117624
3474        # https://github.com/pytorch/pytorch/issues/117627
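        # The boolean results of the logical ops are stored into bfloat16 /
        # float16 `out` tensors, so the stores must cast to the low-precision
        # dtype.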
3475        def fn(out1, out2, input, other):
3476            o1 = torch.logical_or(out=out1, input=input, other=other)
3477            o2 = torch.logical_xor(out=out2, input=input, other=other)
3478            return o1, o2
3479
3480        x = torch.rand([3, 3, 2, 8, 9, 2], dtype=torch.float)
3481        y = torch.rand([3, 3, 2, 8, 9, 2], dtype=torch.float)
3482        for dtype in _lowp_fp_dtypes:
3483            o1 = torch.rand([3, 3, 2, 8, 9, 2], dtype=dtype)
3484            o2 = torch.rand([3, 3, 2, 8, 9, 2], dtype=dtype)
3485            with torch.no_grad():
3486                self.common(fn, (o1, o2, x, y))
3487
3488    def test_constant_bool_vec(self):
3489        def fn(x):
3490            mask = torch.zeros(1, dtype=torch.bool)
3491            return torch.where(mask, x, -1.0)
3492
3493        x = torch.rand(1000)
3494        metrics.reset()
3495        self.common(fn, (x,))
3496        check_metrics_vec_kernel_count(1)
3497
3498    @torch._dynamo.config.patch(dynamic_shapes=True)
3499    @torch._dynamo.config.patch(assume_static_by_default=False)
3500    def test_symbolic_shape_scalar_value_reduction(self):
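        # x is passed as a plain Python int, so torch.ones(x) has a symbolic
        # length when compiled with dynamic shapes.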
3501        def fn(x, y):
3502            return y + torch.ones(x).sum()
3503
3504        with torch.no_grad():
3505            metrics.reset()
3506            y = torch.randn(100)
3507            self.common(fn, (100, y))
3508            check_metrics_vec_kernel_count(2)
3509
3510    def test_int32_pointwise_vec(self):
3511        def fn(x):
3512            return x * x
3513
3514        x = torch.randint(0, 100, (32, 32), dtype=torch.int32)
3515        metrics.reset()
3516        self.common(fn, (x,))
3517        check_metrics_vec_kernel_count(1)
3518
3519    def test_int32_reduction_vec(self):
3520        def fn(x):
3521            return x.sum(dim=1)
3522
3523        x = torch.randint(0, 100, (32, 32), dtype=torch.int32)
3524        metrics.reset()
3525        self.common(fn, (x,))
3526        check_metrics_vec_kernel_count(1)
3527
3528    def test_uint32_pointwise_vec(self):
3529        def fn(x):
3530            return x * x
3531
3532        x = torch.randint(0, 100, (32, 32), dtype=torch.uint32)
3533        metrics.reset()
3534        self.common(fn, (x,))
3535        # TODO(jgong5): change to 1 with vectorized uint32 load
3536        assert metrics.generated_cpp_vec_kernel_count == 0
3537
3538    def test_uint32_reduction_vec(self):
3539        def fn(x):
3540            return x.sum(dim=1)
3541
3542        x = torch.randint(0, 100, (32, 32), dtype=torch.uint32)
3543        metrics.reset()
3544        self.common(fn, (x,))
3545        # TODO(jgong5): change to 1 with vectorized uint32/uint64 load
3546        assert metrics.generated_cpp_vec_kernel_count == 0
3547
3548    def test_int64_pointwise_vec(self):
3549        def fn(x):
3550            return x * x
3551
3552        x = torch.randint(0, 100, (32, 32), dtype=torch.int64)
3553        metrics.reset()
3554        self.common(fn, (x,))
3555        check_metrics_vec_kernel_count(1)
3556
3557    def test_int64_reduction_vec(self):
3558        def fn(x):
3559            return x.sum(dim=1)
3560
3561        x = torch.randint(0, 100, (32, 32), dtype=torch.int64)
3562        metrics.reset()
3563        self.common(fn, (x,))
3564        check_metrics_vec_kernel_count(1)
3565
3566    def test_uint64_pointwise_vec(self):
3567        def fn(x):
3568            return x * x
3569
3570        x = torch.randint(0, 100, (32, 32), dtype=torch.uint64)
3571        metrics.reset()
3572        self.common(fn, (x,))
3573        # TODO(jgong5): change to 1 with vectorized uint64 load
3574        assert metrics.generated_cpp_vec_kernel_count == 0
3575
3576    def test_uint64_reduction_vec(self):
3577        def fn(x):
3578            return x.sum(dim=1)
3579
3580        x = torch.randint(0, 100, (32, 32), dtype=torch.uint64)
3581        metrics.reset()
3582        self.common(fn, (x,))
3583        # TODO(jgong5): change to 1 with vectorized uint64 load
3584        assert metrics.generated_cpp_vec_kernel_count == 0
3585
3586    def test_convert_int32_to_int64_vec(self):
3587        def fn(x):
3588            return x.to(torch.int64)
3589
3590        x = torch.randint(0, 100, (32, 32), dtype=torch.int32)
3591        metrics.reset()
3592        self.common(fn, (x,))
3593        check_metrics_vec_kernel_count(1)
3594
3595    def test_convert_int64_to_int32_vec(self):
3596        def fn(x):
3597            return x.to(torch.int32)
3598
3599        x = torch.randint(0, 100, (32, 32), dtype=torch.int64)
3600        metrics.reset()
3601        self.common(fn, (x,))
3602        check_metrics_vec_kernel_count(1)
3603
3604    def test_convert_fp32_to_int64_vec(self):
3605        def fn(x):
3606            return x.to(torch.int64)
3607
3608        x = torch.rand(32, 32)
3609        metrics.reset()
3610        self.common(fn, (x,))
3611        check_metrics_vec_kernel_count(1)
3612
3613    def test_convert_int64_to_fp32_vec(self):
3614        def fn(x):
3615            return x.to(torch.float32)
3616
3617        x = torch.randint(0, 100, (32, 32), dtype=torch.int64)
3618        metrics.reset()
3619        self.common(fn, (x,))
3620        check_metrics_vec_kernel_count(1)
3621
3622    def test_no_redundant_to_dtypes_between_fused_scheduler_node(self):
3623        # https://github.com/pytorch/pytorch/issues/115260
3624        p0 = torch.tensor([1.0879], dtype=torch.float16)
3625
3626        class Model1(torch.nn.Module):
3627            def __init__(self):
3628                super().__init__()
3629
3630            def forward(self, *args):
3631                cat = torch.cat((args[3], args[2], args[1], args[0]), dim=2)
3632                max_1 = torch.max(args[4], p0)
3633                mul = torch.mul(cat, max_1)
3634                tan = torch.tan(mul)
3635                return (mul, tan)
3636
3637        metrics.reset()
3638        m = Model1()
3639        self.common(
3640            m,
3641            (
3642                torch.randn((17, 5, 1, 7)).half(),
3643                torch.randn((17, 5, 1, 7)).half(),
3644                torch.randn((17, 5, 11, 7)).half(),
3645                torch.randn((17, 5, 1, 7)).half(),
3646                torch.tensor(4.39, dtype=torch.float16),
3647            ),
3648        )
3649
3650    def test_masked_load_int64_vec(self):
3651        # https://github.com/pytorch/pytorch/issues/120377
3652        def fn(x):
3653            return torch.nn.functional.pad(x, (0, 13))
3654
3655        x = torch.randint(0, 100, (819,), dtype=torch.int64)
3656        metrics.reset()
3657        self.common(fn, (x,))
3658        assert metrics.generated_cpp_vec_kernel_count == 1
3659
3660    def test_reduction_float_to_int64(self):
3661        # https://github.com/pytorch/pytorch/issues/124821
3662        def fn(x):
3663            return x.max(0).values
3664
3665        x = torch.randint(0, 100, (22, 51), dtype=torch.int64)
3666        metrics.reset()
3667        self.common(fn, (x,))
3668        assert metrics.generated_cpp_vec_kernel_count == 1
3669
3670    @config.patch({"cpp.dynamic_threads": True})
3671    def test_reduction_with_dynamic_threads(self):
3672        def fn(a, b):
3673            return a.sum(), b.sum()
3674
3675        self.common(
3676            fn,
3677            (torch.randn(1000), torch.rand(1000)),
3678        )
3679
3680    @patch("torch.cuda.is_available", lambda: False)
3681    @config.patch(freezing=True)
3682    def test_linear_float64(self):
3683        class M(torch.nn.Module):
3684            def __init__(self):
3685                super().__init__()
3686                self.weight1 = torch.nn.Parameter(
3687                    torch.randn(10, 10, dtype=torch.float64)
3688                )
3689                self.weight2 = torch.nn.Parameter(
3690                    torch.randn(10, 10, dtype=torch.float64)
3691                )
3692                self.bias = torch.nn.Parameter(torch.randn(10, dtype=torch.float64))
3693
3694            def forward(self, x1):
3695                v1 = torch.mm(x1, self.weight1)
3696                v2 = torch.addmm(self.bias, x1, self.weight2)
3697                return (v1, v2)
3698
3699        mod = M().eval()
3700        v = torch.randn(10, 10, dtype=torch.float64)
3701        with torch.no_grad():
3702            self.common(
3703                mod,
3704                (v,),
3705            )
3706
3707    def test_fused_attention_conv(self):
3708        # https://github.com/pytorch/pytorch/issues/121174.
3709        class Model(torch.nn.Module):
3710            def __init__(self):
3711                super().__init__()
3712                self.q_conv = torch.nn.Conv2d(4, 4, 1)
3713                self.k_conv = torch.nn.Conv2d(4, 4, 1)
3714                self.v_conv = torch.nn.Conv2d(4, 4, 1)
3715
3716            def forward(self, x):
3717                q = self.q_conv(x)
3718                k = self.k_conv(x)
3719                v = self.v_conv(x)
3720                q = q.permute(0, 2, 1, 3)
3721                k = k.permute(0, 2, 1, 3)
3722                v = v.permute(0, 2, 1, 3)
3723                return torch.nn.functional.scaled_dot_product_attention(
3724                    q, k, v, dropout_p=0.0, is_causal=False
3725                )
3726
3727        fn = Model()
3728        x = torch.randn(1, 4, 2, 2)
3729        self.common(fn, (x,))
3730
3731    @requires_vectorization
3732    def test_vec_indirect_load_cse_cache(self):
3733        # https://github.com/pytorch/pytorch/issues/123502
3734        from math import inf
3735
3736        def fn(arg0_1):
3737            full_default = torch.ops.aten.full.default([209985], 1)
3738            select = torch.ops.aten.select.int(arg0_1, 0, 0)
3739            select_1 = torch.ops.aten.select.int(arg0_1, 0, 1)
3740            view = torch.ops.aten.reshape.default(select_1, [-1])
3741            expand = torch.ops.aten.expand.default(view, [209985])
3742            full_default_1 = torch.ops.aten.full.default([10000], 0)
3743            scatter_add = torch.ops.aten.scatter_add.default(
3744                full_default_1, 0, expand, full_default
3745            )
3746            pow_1 = torch.ops.aten.pow.Tensor_Scalar(scatter_add, -0.5)
3747            eq = torch.ops.aten.eq.Scalar(pow_1, inf)
3748            full_default_2 = torch.ops.aten.full.default([], 0.0)
3749            where = torch.ops.aten.where.self(eq, full_default_2, pow_1)
3750            index = torch.ops.aten.index.Tensor(where, [select])
3751            index_1 = torch.ops.aten.index.Tensor(where, [select_1])
3752            mul_1 = torch.ops.aten.mul.Tensor(index, index_1)
3753            return (mul_1,)
3754
3755        x = torch.zeros(2, 209985).to(torch.int64)
3756        opt_fn = torch._dynamo.optimize("inductor")(fn)
3757        _, code = run_and_get_cpp_code(opt_fn, x)
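        # Indirect loads through `where[...]` should hit the CSE cache; the
        # exact count below guards against emitting redundant vector loads.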
3758        FileCheck().check_count(
3759            "return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(),",
3760            4,
3761            exactly=True,
3762        ).run(code)
3763
3764
3765if __name__ == "__main__":
3766    from torch._inductor.test_case import run_tests
3767    from torch.testing._internal.inductor_utils import HAS_CPU
3768
3769    if HAS_CPU and not IS_MACOS:
3770        run_tests(needs="filelock")
3771