xref: /aosp_15_r20/external/pytorch/test/inductor/test_torchinductor.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: inductor"]
2import contextlib
3import copy
4import dataclasses
5import functools
6import gc
7import importlib
8import itertools
9import math
10import operator
11import os
12import random
13import re
14import subprocess
15import sys
16import threading
17import time
18import typing
19import unittest
20import unittest.mock
21import weakref
22from pathlib import Path
23from typing import Tuple
24from unittest.mock import patch
25
26import numpy as np
27
28import torch
29
30import torch._dynamo.config as dynamo_config
31import torch.nn as nn
32from torch._dispatch.python import enable_python_dispatcher
33from torch._dynamo.debug_utils import aot_graph_input_parser
34from torch._dynamo.testing import (
35    CompileCounterWithBackend,
36    expectedFailureCodegenDynamic,
37    rand_strided,
38    same,
39    skipIfPy312,
40)
41from torch._dynamo.utils import ifdynstaticdefault
42from torch._inductor.codegen.common import DataTypePropagation, OptimizationContext
43from torch._inductor.fx_passes import pad_mm
44from torch._inductor.test_case import TestCase as InductorTestCase
45from torch._inductor.utils import (
46    add_scheduler_init_hook,
47    aoti_compile_with_persistent_cache,
48    aoti_eager_cache_dir,
49    load_aoti_eager_cache,
50    run_and_get_code,
51    run_and_get_cpp_code,
52    run_and_get_triton_code,
53)
54from torch._inductor.virtualized import V
55from torch._prims_common import is_integer_dtype
56from torch.fx.experimental.proxy_tensor import make_fx
57from torch.library import _scoped_library
58from torch.nn import functional as F
59from torch.testing import FileCheck, make_tensor
60from torch.testing._internal.common_cuda import (
61    PLATFORM_SUPPORTS_FLASH_ATTENTION,
62    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
63    SM80OrLater,
64    TEST_CUDNN,
65    with_tf32_off,
66)
67
68from torch.testing._internal.common_device_type import (
69    _has_sufficient_memory,
70    expectedFailureXPU,
71)
72from torch.testing._internal.common_dtype import all_types, get_all_dtypes
73from torch.testing._internal.common_utils import (
74    DeterministicGuard,
75    instantiate_parametrized_tests,
76    IS_CI,
77    IS_FBCODE,
78    IS_MACOS,
79    IS_WINDOWS,
80    IS_X86,
81    parametrize,
82    serialTest,
83    skipIfNNModuleInlined,
84    skipIfRocm,
85    skipIfXpu,
86    subtest,
87    TEST_WITH_ASAN,
88    TEST_WITH_ROCM,
89)
90from torch.utils import _pytree as pytree
91from torch.utils._python_dispatch import TorchDispatchMode
92from torch.utils._pytree import tree_flatten, tree_unflatten
93from torch.utils.weak import WeakTensorKeyDictionary
94
95DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"
96
97if IS_WINDOWS and IS_CI:
98    sys.stderr.write(
99        "Windows CI does not have necessary dependencies for test_torchinductor yet\n"
100    )
101    if __name__ == "__main__":
102        sys.exit(0)
103    raise unittest.SkipTest("requires sympy/functorch/filelock")
104
105importlib.import_module("functorch")
106importlib.import_module("filelock")
107
108from torch._inductor import config, test_operators
109
110from torch._inductor.compile_fx import (
111    compile_fx,
112    compile_fx_inner,
113    complex_memory_overlap,
114)
115from torch._inductor.utils import has_torchvision_roi_align
116
117from torch.testing._internal.common_utils import slowTest
118from torch.testing._internal.inductor_utils import (
119    GPU_TYPE,
120    HAS_CPU,
121    HAS_GPU,
122    HAS_MULTIGPU,
123    skipCPUIf,
124    skipCUDAIf,
125)
126
127HAS_AVX2 = "fbgemm" in torch.backends.quantized.supported_engines
128
129aten = torch.ops.aten
130requires_gpu = functools.partial(unittest.skipIf, not HAS_GPU, "requires gpu")
131
132requires_multigpu = functools.partial(
133    unittest.skipIf, not HAS_MULTIGPU, f"requires multiple {GPU_TYPE} devices"
134)
135skip_if_x86_mac = functools.partial(
136    unittest.skipIf, IS_MACOS and IS_X86, "Does not work on x86 Mac"
137)
138vec_dtypes = [torch.float, torch.bfloat16, torch.float16]
139
140libtest = torch.library.Library("test", "FRAGMENT")  # noqa: TOR901
141ids = set()
142
143f32 = torch.float32
144i64 = torch.int64
145i32 = torch.int32
146
147
148def _large_cumprod_input(shape, dim, dtype, device):
149    # Construct a cumprod input which guaruntees not to overflow or underflow
150    if is_integer_dtype(dtype):
151        # Large products don't fit in integers, the best we can do
152        # is random +/-1 values to test the sign of the result
153        x = torch.randint(0, 1, shape, dtype=dtype, device=device)
154        return x * 2 - 1
155
156    comp_dtype = torch._prims_common.get_computation_dtype(dtype)
157    batch_size = 256
158    if comp_dtype != dtype:
159        batch_size = math.floor(math.log2(torch.finfo(dtype).max) / 3)
160
161    # Create random values with a uniform magnitude and uniform exponent
162    num_batches = (shape[dim] + 2 * batch_size - 1) // (2 * batch_size)
163    batch_shape = (
164        shape[:dim]
165        + (
166            num_batches,
167            batch_size,
168        )
169        + shape[dim + 1 :]
170    )
171    magnitude = 1 + torch.rand(batch_shape, dtype=comp_dtype, device=device)
172    exponent = torch.randint(-1, 1, batch_shape, device=device).to(comp_dtype)
173    batch = magnitude * exponent.exp2()
174
175    # Alternate each batch of values with their reciprocals so the product
176    # never gets too far away from 1
177    t = torch.cat((batch, batch.reciprocal()), dim=dim + 1)
178    t = t.flatten(dim, dim + 1)
179    t = aten.slice(t, dim=dim, start=0, end=shape[dim])
180
181    # Randomize sign
182    sign = torch.randint(0, 1, shape, device=device) * 2 - 1
183    return (t * sign).to(dtype)
184
185
186def define_custom_op_for_test(id_, fn_cpu, fn_cuda, fn_xpu, fn_meta, tags=()):
187    global libtest
188    global ids
189    if id_ not in ids:
190        libtest.define(f"{id_}(Tensor self) -> Tensor", tags=tags)
191        libtest.impl(id_, fn_cpu, "CPU")
192        libtest.impl(id_, fn_cuda, "CUDA")
193        libtest.impl(id_, fn_xpu, "XPU")
194        libtest.impl(id_, fn_meta, "Meta")
195        ids.add(id_)
196
197
198def define_custom_op_2_for_test(id_, fn_cpu, fn_cuda, fn_xpu, fn_meta, tags=()):
199    global libtest
200    global ids
201    if id_ not in ids:
202        libtest.define(
203            f"{id_}(Tensor self, float scale) -> (Tensor, Tensor)", tags=tags
204        )
205        libtest.impl(id_, fn_cpu, "CPU")
206        libtest.impl(id_, fn_cuda, "CUDA")
207        libtest.impl(id_, fn_xpu, "XPU")
208        libtest.impl(id_, fn_meta, "Meta")
209        ids.add(id_)
210
211
212def define_custom_op_3_for_test(id_, fn_cpu, fn_cuda, fn_xpu, fn_meta, tags=()):
213    global libtest
214    global ids
215    if id_ not in ids:
216        libtest.define(f"{id_}(Tensor[] x) -> Tensor", tags=tags)
217        libtest.impl(id_, fn_cpu, "CPU")
218        libtest.impl(id_, fn_cuda, "CUDA")
219        libtest.impl(id_, fn_xpu, "XPU")
220        libtest.impl(id_, fn_meta, "Meta")
221        ids.add(id_)
222
223
224f32 = torch.float32
225
226
227def run_fw_bw_and_get_code(fn):
228    def run_with_backward():
229        result = fn()
230        result.sum().backward()
231        return result
232
233    return run_and_get_code(run_with_backward)
234
235
236def register_ops_with_aoti_compile(ns, op_set, dispatch_key, torch_compile_op_lib_impl):
237    for _op_name in op_set:
238        qualified_op_name = f"{ns}::{_op_name}"
239        _, overload_names = torch._C._jit_get_operation(qualified_op_name)
240        for overload_name in overload_names:
241            try:
242                reg_op_name = qualified_op_name
243                schema = torch._C._get_schema(qualified_op_name, overload_name)
244                if schema.overload_name:
245                    reg_op_name = f"{qualified_op_name}.{schema.overload_name}"
246                torch_compile_op_lib_impl._impl_with_aoti_compile(  # noqa: F821
247                    reg_op_name, dispatch_key
248                )
249            except Exception as e:
250                continue
251
252
253class TestCase(InductorTestCase):
254    @classmethod
255    def setUpClass(cls):
256        super().setUpClass()
257        cls._stack = contextlib.ExitStack()
258        cls._stack.enter_context(
259            config.patch(
260                {
261                    "debug": True,
262                    "debug_index_asserts": True,
263                    "cpp.min_chunk_size": 1,
264                    "triton.autotune_pointwise": False,  # too slow
265                    "implicit_fallbacks": False,
266                    "generate_intermediate_hooks": True,
267                }
268            )
269        )
270
271    @classmethod
272    def tearDownClass(cls):
273        cls._stack.close()
274        super().tearDownClass()
275
276    def setUp(self):
277        torch._dynamo.reset()
278        torch._inductor.metrics.reset()
279        super().setUp()
280        self._start = time.perf_counter()
281
282    def tearDown(self):
283        super().tearDown()
284        torch._dynamo.reset()
285        if os.environ.get("ERROR_ON_SLOW") == "1":
286            elapsed = time.perf_counter() - self._start
287            assert elapsed < 120
288
289
290class ToTuple(torch.nn.Module):
291    def forward(self, x):
292        return (x,)
293
294
295@dataclasses.dataclass
296class InputGen:
297    n: int
298    device: str
299
300    def dense(self):
301        return torch.randn((self.n, self.n), device=self.device)
302
303    def transposed(self):
304        return self.dense().transpose(0, 1)
305
306    def strided(self):
307        return torch.randn((self.n * 2, self.n * 3), device=self.device)[
308            self.n :, self.n :: 2
309        ]
310
311    def broadcast1(self):
312        return torch.randn((self.n,), device=self.device)
313
314    def broadcast2(self):
315        return torch.randn((1, self.n, 1), device=self.device)
316
317    def broadcast3(self):
318        return torch.randn((1,), device=self.device)
319
320    def double(self):
321        return torch.randn((self.n, self.n), device=self.device, dtype=torch.double)
322
323    def int(self):
324        return torch.arange(self.n, device=self.device, dtype=torch.int32)
325
326
327def compute_grads(args, kwrags, results, grads):
328    def gather_leaf_tensors(args, kwargs):
329        args = pytree.arg_tree_leaves(*args, **kwargs)
330        leaf_tensors = [
331            arg for arg in args if isinstance(arg, torch.Tensor) and arg.requires_grad
332        ]
333        return leaf_tensors
334
335    flat_results = pytree.tree_leaves(results)
336    flat_diff_results = [
337        r for r in flat_results if isinstance(r, torch.Tensor) and r.requires_grad
338    ]
339    assert len(flat_diff_results) > 0
340
341    leaf_tensors = gather_leaf_tensors(args, kwrags)
342    assert len(leaf_tensors) > 0
343    return torch.autograd.grad(
344        flat_diff_results,
345        leaf_tensors,
346        grads,
347        allow_unused=True,
348        retain_graph=True,
349    )
350
351
352def clone_preserve_strides(x, device=None):
353    if not isinstance(x, torch.Tensor):
354        return x
355    buffer = torch.as_strided(
356        x, (x.untyped_storage().size() // x.element_size(),), (1,), 0
357    )
358    if not device:
359        buffer = buffer.clone()
360    else:
361        buffer = buffer.to(device, copy=True)
362    out = torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset())
363    return out
364
365
366def check_model(
367    self: TestCase,
368    model,
369    example_inputs,
370    kwargs=None,
371    *,
372    atol=None,
373    rtol=None,
374    grad_atol=None,
375    grad_rtol=None,
376    check_lowp=True,
377    exact_dtype=True,
378    nopython=True,
379    copy_to_gpu=True,
380    reference_in_float=True,
381    assert_equal=True,
382    check_gradient=False,
383    check_has_compiled=True,
384    output_process_fn_grad=lambda x: x,
385):
386    kwargs = kwargs or {}
387    torch._dynamo.reset()
388
389    ref_inputs = [clone_preserve_strides(x) for x in example_inputs]
390    ref_kwargs = kwargs
391    has_lowp_args = False
392
393    if reference_in_float and exact_dtype:
394        # Store expected dtypes so we can check actual result gives the correct types
395        torch.manual_seed(0)
396        try:
397            eager_result = model(*ref_inputs, **ref_kwargs)
398        except RuntimeError:
399            # Eager model may fail if the dtype is not supported
400            eager_result = None
401
402        ref_inputs = [clone_preserve_strides(x) for x in example_inputs]
403        expect_dtypes = [
404            x.dtype if isinstance(x, torch.Tensor) else None
405            for x in pytree.tree_leaves(eager_result)
406        ]
407        del eager_result
408
409    ref_model = model
410    if reference_in_float:
411        # check_lowp is ignored here, it's kept just to be able to call `common` with extra arg
412        def upcast_fn(x):
413            nonlocal has_lowp_args
414            if isinstance(x, torch.Tensor) and (
415                x.dtype == torch.float16 or x.dtype == torch.bfloat16
416            ):
417                has_lowp_args = True
418                return x.float()
419            else:
420                return x
421
422        ref_inputs = list(map(upcast_fn, example_inputs))
423        ref_kwargs = {k: upcast_fn(v) for k, v in kwargs.items()}
424        if has_lowp_args and hasattr(model, "to"):
425            ref_model = copy.deepcopy(model).to(torch.float)
426
427    torch.manual_seed(0)
428
429    correct = ref_model(*ref_inputs, **ref_kwargs)
430
431    torch._inductor.metrics.reset()
432
433    called = False
434
435    def compile_fx_wrapper(model_, example_inputs_):
436        nonlocal called
437        called = True
438        return compile_fx(model_, example_inputs_)
439
440    def run(*ex, **kwargs):
441        return model(*ex, **kwargs)
442
443    run = torch._dynamo.optimize(compile_fx_wrapper, nopython=nopython)(run)
444
445    torch.manual_seed(0)
446    actual = run(*example_inputs, **kwargs)
447    # if not called:
448    #     exp = torch._dynamo.explain(run)(*example_inputs)
449    #     print("Explain:", exp[0])
450    #     for graph in exp[2]:
451    #         print("Graph", graph)
452    if check_has_compiled:
453        assert called, "Ran graph without calling compile_fx"
454    assert type(actual) == type(correct)
455    if isinstance(actual, (tuple, list)):
456        assert len(actual) == len(correct)
457        assert all(
458            type(actual_item) == type(correct_item)
459            for actual_item, correct_item in zip(actual, correct)
460        )
461
462    correct_flat, correct_spec = tree_flatten(correct)
463    actual_flat = pytree.tree_leaves(actual)
464
465    def reference_to_expect(actual_flat, correct_flat):
466        return tuple(
467            (
468                y.to(x.dtype)
469                if isinstance(y, torch.Tensor) and y.dtype.is_floating_point
470                else y
471            )
472            for x, y in zip(actual_flat, correct_flat)
473        )
474
475    if reference_in_float and exact_dtype:
476        for expect_dtype, actual_result in zip(expect_dtypes, actual_flat):
477            if expect_dtype is not None:
478                assert (
479                    actual_result.dtype == expect_dtype
480                ), f"dtype mismatch, expected {expect_dtype} but got {actual_result.dtype}"
481
482    if reference_in_float:
483        correct_flat = reference_to_expect(actual_flat, correct_flat)
484        correct = tree_unflatten(correct_flat, correct_spec)
485
486    if assert_equal:
487        self.assertEqual(
488            actual,
489            correct,
490            atol=atol,
491            rtol=rtol,
492            equal_nan=True,
493            exact_dtype=exact_dtype,
494        )
495        # In case of input mutations, check that inputs are the same
496        self.assertEqual(
497            ref_inputs,
498            example_inputs,
499            atol=atol,
500            rtol=rtol,
501            equal_nan=True,
502            # our testing sometimes uses higher precision inputs for the reference
503            exact_dtype=False,
504        )
505    else:
506        for correct_val, actual_val in zip(correct_flat, actual_flat):
507            if isinstance(correct_val, torch.Tensor):
508                assert correct_val.device == actual_val.device
509                assert correct_val.size() == actual_val.size()
510                strides_equal, _ = torch._prims_common.check_significant_strides(
511                    correct_val, actual_val
512                )
513                assert strides_equal
514                assert correct_val.layout == actual_val.layout
515                if exact_dtype:
516                    assert correct_val.dtype == actual_val.dtype
517
518    if check_gradient:
519        actual = output_process_fn_grad(actual)
520        correct = output_process_fn_grad(correct)
521        actual_flat = pytree.tree_leaves(actual)
522        correct_flat = pytree.tree_leaves(correct)
523
524        # generate random unit norm gradients
525        grads = [
526            torch.rand(r.shape, device=r.device, dtype=r.dtype)
527            for r in correct_flat
528            if isinstance(r, torch.Tensor) and r.requires_grad
529        ]
530        for g in grads:
531            g /= g.norm()
532
533        correct_grad = compute_grads(ref_inputs, ref_kwargs, correct, grads)
534        all_none_grads = all(x is None for x in correct_grad)
535        if all_none_grads:
536            # See Note [Detaching inputs that never need gradients]
537            # There are a handful of ops that can return None gradients, into of zero gradients.
538            # If all inputs to an AOTAutograd graph are supposed to get None gradients,
539            # AOTAutograd will end up forcing all of the outputs of the forward to not require grad.
540            # There's no easy fix to this (see the note above), although one option is to
541            # force any derivative formulas in core to return tensors of zeros instead of None.
542            flat_results = pytree.tree_leaves(actual)
543            results_that_require_grad = [
544                x
545                for x in flat_results
546                if isinstance(x, torch.Tensor) and x.requires_grad
547            ]
548            self.assertEqual(len(results_that_require_grad), 0)
549        else:
550            actual_grad = compute_grads(example_inputs, kwargs, actual, grads)
551
552            if reference_in_float:
553                expect_grad = reference_to_expect(actual_grad, correct_grad)
554            else:
555                expect_grad = correct_grad
556
557            self.assertEqual(
558                actual_grad,
559                expect_grad,
560                atol=grad_atol or atol,
561                rtol=grad_rtol or rtol,
562                equal_nan=True,
563                exact_dtype=exact_dtype,
564            )
565
566    torch._dynamo.reset()
567
568
569@torch._inductor.config.patch("triton.cudagraphs", False)
570def check_model_gpu(
571    self: TestCase,
572    model,
573    example_inputs,
574    kwargs=None,
575    *,
576    atol=None,
577    rtol=None,
578    grad_atol=None,
579    grad_rtol=None,
580    check_lowp=True,
581    exact_dtype=True,
582    nopython=True,
583    copy_to_gpu=True,
584    reference_in_float=True,
585    assert_equal=True,
586    check_gradient=False,
587    check_has_compiled=True,
588    output_process_fn_grad=lambda x: x,
589):
590    kwargs = kwargs or {}
591    if hasattr(model, "to"):
592        model = model.to(device=GPU_TYPE)
593
594    if copy_to_gpu:
595        example_inputs = tuple(
596            clone_preserve_strides(x, device=GPU_TYPE) for x in example_inputs
597        )
598
599    check_model(
600        self,
601        model,
602        example_inputs,
603        kwargs,
604        atol=atol,
605        rtol=rtol,
606        grad_atol=grad_atol,
607        grad_rtol=grad_rtol,
608        exact_dtype=exact_dtype,
609        nopython=nopython,
610        reference_in_float=reference_in_float,
611        assert_equal=assert_equal,
612        check_gradient=check_gradient,
613        check_has_compiled=check_has_compiled,
614        output_process_fn_grad=output_process_fn_grad,
615    )
616
617    if check_lowp:
618
619        def downcast_fn(x):
620            if not isinstance(x, torch.Tensor) or not x.dtype == torch.float:
621                return x
622            return torch.empty_strided(
623                x.size(), x.stride(), device=GPU_TYPE, dtype=torch.half
624            ).copy_(x)
625
626        example_inputs = list(map(downcast_fn, example_inputs))
627        if hasattr(model, "to"):
628            model = model.to(torch.half)
629        if rtol is not None:
630            rtol = max(2e-3, rtol)
631        check_model(
632            self,
633            model,
634            example_inputs,
635            kwargs,
636            atol=atol,
637            rtol=rtol,
638            grad_atol=grad_atol,
639            grad_rtol=grad_rtol,
640            exact_dtype=exact_dtype,
641            nopython=nopython,
642            reference_in_float=reference_in_float,
643            assert_equal=assert_equal,
644            check_gradient=check_gradient,
645            check_has_compiled=check_has_compiled,
646            output_process_fn_grad=output_process_fn_grad,
647        )
648
649
650check_model_cuda = check_model_gpu
651
652
653def _run_and_assert_no_indirect_indexing(
654    test_case, func, *args, has_wrapping=None, has_assert=False, **kwargs
655):
656    result, source_codes = run_and_get_code(func, *args, **kwargs)
657
658    for code in source_codes:
659        for line in code.split("\n"):
660            stmt = None
661            # Find indexing expressions
662            if ".load(" in line:
663                stmt = line.split(".load")[-1]
664            elif "tl.store" in line:
665                stmt = line.split(".store")[-1]
666                stmt = ",".join(stmt.split(",")[:-2])  # Remove store value and mask
667            elif ".store" in line:
668                stmt = line.split(".store")[-1]
669            elif "[" in line:
670                stmt = line.split("[")[-1].split("]")[0]
671            if "tl.make_block_ptr(" in line:
672                continue
673
674            if stmt is None:
675                continue
676
677            # indirect indexing involves a `tmp` variable
678            test_case.assertTrue(
679                "tmp" not in stmt,
680                msg=f"Found indirect indexing in statement '{stmt}' from code:\n{code}",
681            )
682        if has_wrapping is not None:
683            test_case.assertTrue(
684                ("where" in code or "?" in code) is has_wrapping,
685                msg=f"Wanted {has_wrapping=} but got\n{code}",
686            )
687    test_case.assertTrue(
688        any(
689            ("device_assert" in code or "TORCH_CHECK" in code) is has_assert
690            for code in source_codes
691        )
692    )
693    return result
694
695
696def assertGeneratedKernelCountEqual(self: TestCase, expected: int):
697    if config.triton.multi_kernel:
698        # when multi_kernel is enabled, we generated both persistent reduction
699        # and non-persistent reduction kernels for the same node schedule.
700        # That will mess up with the kernel count. Just don't check it.
701        return
702    if config.cpp_wrapper:
703        expected *= 2
704    self.assertEqual(torch._inductor.metrics.generated_kernel_count, expected)
705
706
707class SweepInputs2:
708    input_gen_types1 = [
709        "dense",
710        "transposed",
711        "strided",
712        "broadcast1",
713        "broadcast2",
714        "broadcast3",
715        "double",
716        "int",
717    ]
718    input_gen_types2 = input_gen_types1
719    gen = None
720
721    @staticmethod
722    def kernel(a, b):
723        return (a + b,)
724
725    @classmethod
726    def gen_template(cls, name1, name2):
727        def test(self):
728            check_model(
729                self,
730                cls.kernel,
731                (
732                    getattr(cls.gen, name1)(),
733                    getattr(cls.gen, name2)(),
734                ),
735            )
736
737        test.__name__ = f"test_{cls.gen.device}_{name1}_{name2}"
738        setattr(cls, test.__name__, test)
739
740    @classmethod
741    def populate(cls):
742        for name1 in cls.input_gen_types1:
743            for name2 in cls.input_gen_types2:
744                cls.gen_template(name1, name2)
745
746
747@instantiate_parametrized_tests
748class CommonTemplate:
749    def test_bool(self):
750        def fn(a, b):
751            return (
752                a + b,
753                a * b,
754                a & b,
755                a | b,
756                a ^ b,
757                torch.logical_and(a, b),
758                torch.logical_or(a, b),
759                torch.logical_not(a),
760                torch.sign(b),
761            )
762
763        self.common(
764            fn,
765            (
766                torch.tensor([True, False, True, False]),
767                torch.tensor([False, False, True, True]),
768            ),
769        )
770
771    @skipCUDAIf(not SM80OrLater, "Requires sm80")
772    def test_eager_aoti_support_out(self):
773        ns = "aten"
774        op_name = "clamp"
775        dispatch_key = "CPU"
776        device = "cpu"
777        if self.device.lower() == "cuda":
778            dispatch_key = "CUDA"
779            device = "cuda"
780
781        inp_tensor = torch.randn(128, dtype=torch.float, device=device).fill_(1.0)
782        min_tensor = inp_tensor - 0.05
783        max_tensor = inp_tensor + 0.05
784        with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
785            ref_out_tensor = torch.randn(128, dtype=torch.float, device=device).fill_(
786                -1
787            )
788            ref_tensor = torch.clamp(
789                max=max_tensor, min=min_tensor, input=inp_tensor, out=ref_out_tensor
790            )
791
792            ref_out_tensor1 = torch.randn(128, dtype=torch.float, device=device).fill_(
793                -1
794            )
795            ref_tensor1 = torch.clamp(
796                max=max_tensor, out=ref_out_tensor1, min=min_tensor, input=inp_tensor
797            )
798
799            register_ops_with_aoti_compile(
800                ns, [op_name], dispatch_key, torch_compile_op_lib_impl
801            )
802
803            res_out_tensor = torch.randn(128, dtype=torch.float, device=device).fill_(
804                -1
805            )
806            res_tensor = torch.clamp(
807                max=max_tensor, min=min_tensor, input=inp_tensor, out=res_out_tensor
808            )
809
810            self.assertEqual(ref_tensor, res_tensor)
811            self.assertEqual(ref_out_tensor, res_out_tensor)
812
813            res_out_tensor1 = torch.randn(128, dtype=torch.float, device=device).fill_(
814                -1
815            )
816            res_tensor1 = torch.clamp(
817                max=max_tensor, out=res_out_tensor1, min=min_tensor, input=inp_tensor
818            )
819
820            self.assertEqual(ref_tensor1, res_tensor1)
821            self.assertEqual(ref_out_tensor1, res_out_tensor1)
822
823    @skipCUDAIf(not SM80OrLater, "Requires sm80")
824    def test_eager_aoti_cache_hit(self):
825        ns = "aten"
826        op_name = "abs"
827        dispatch_key = "CPU"
828        device = "cpu"
829        if self.device.lower() == "cuda":
830            dispatch_key = "CUDA"
831            device = "cuda"
832
833        input_tensor = torch.randn(128, dtype=torch.float, device=device)
834        kernel_lib_path = aoti_compile_with_persistent_cache(
835            ns,
836            op_name,
837            device,
838            False,
839            getattr(torch.ops.aten, op_name),
840            (input_tensor,),
841            {},
842        )
843        self.assertTrue(Path(kernel_lib_path).exists())
844
845        from unittest import mock
846
847        # Patch the aoti_compile_with_persistent_cache as None to ensure no new kernel is generated
848        with mock.patch(
849            "torch._inductor.utils.aoti_compile_with_persistent_cache", None
850        ):
851            with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
852                # Get ref result from eager
853                ref_value = getattr(torch.ops.aten, op_name)(input_tensor)
854
855                register_ops_with_aoti_compile(
856                    ns, [op_name], dispatch_key, torch_compile_op_lib_impl
857                )
858
859                # Invoke the pre-compiled kernel and get result.
860                res_value = getattr(torch.ops.aten, op_name)(input_tensor)
861
862                self.assertEqual(ref_value, res_value)
863
864    @skipCUDAIf(not SM80OrLater, "Requires sm80")
865    def test_eager_aoti_with_persistent_cache(self):
866        def fn(a):
867            return torch.abs(a)
868
869        ns = "aten"
870        op_name = "abs"
871
872        device = "cpu"
873        if self.device.lower() == "cuda":
874            device = "cuda"
875
876        input_tensor = torch.randn(128, dtype=torch.float, device=device)
877        kernel_lib_path = aoti_compile_with_persistent_cache(
878            ns,
879            op_name,
880            input_tensor.device.type,
881            False,
882            fn,
883            args=(input_tensor,),
884            kwargs={},
885        )
886        self.assertTrue(len(kernel_lib_path) > 0)
887
888        device_kernel_cache = aoti_eager_cache_dir(ns, device)
889        kernel_conf = device_kernel_cache / f"{op_name}.json"
890        self.assertTrue(kernel_conf.exists())
891
892        json_data = load_aoti_eager_cache("aten", "abs", input_tensor.device.type)
893        self.assertTrue(json_data is not None)
894        self.assertTrue(isinstance(json_data, list))
895        self.assertTrue(len(json_data) > 0)
896
897        op_info = json_data[0]
898        self.assertTrue(isinstance(op_info, dict))
899        self.assertTrue("meta_info" in op_info)
900        self.assertTrue("kernel_path" in op_info)
901        kernel_libs_abs_path = []
902        for item in json_data:
903            kernel_path = device_kernel_cache / item["kernel_path"]
904            kernel_libs_abs_path.append(kernel_path.as_posix())
905
906        self.assertTrue(kernel_lib_path in kernel_libs_abs_path)
907
908    @skipCUDAIf(not SM80OrLater, "Requires sm80")
909    def test_eager_aoti_with_scalar(self):
910        namespace_name = "aten"
911        op_name = "add"
912        op_overload_name = "Tensor"
913        op_name_with_overload = f"{op_name}.{op_overload_name}"
914
915        dispatch_key = "CPU"
916        device = torch.device("cpu")
917        if self.device.lower() == "cuda":
918            dispatch_key = "CUDA"
919            device = torch.device("cuda")
920
921        # Test the difference between scalar tensor and scalar
922        a = torch.scalar_tensor(1.0, device=device)
923        b = torch.scalar_tensor(2.0, device=device)
924
925        kernel_lib_path = aoti_compile_with_persistent_cache(
926            namespace_name,
927            op_name_with_overload,
928            a.device.type,
929            False,
930            torch.ops.aten.add,
931            args=(a, b),
932            kwargs={"alpha": 3.0},
933        )
934        self.assertTrue(Path(kernel_lib_path).exists())
935        device_kernel_cache = aoti_eager_cache_dir(namespace_name, device.type)
936        kernel_conf = device_kernel_cache / f"{op_name_with_overload}.json"
937        self.assertTrue(kernel_conf.exists())
938        json_data = load_aoti_eager_cache(
939            namespace_name, op_name_with_overload, a.device.type
940        )
941        op_info = json_data[0]
942        self.assertTrue(isinstance(op_info, dict))
943        self.assertTrue("meta_info" in op_info)
944        self.assertTrue(len(op_info["meta_info"]) == 3)
945        self.assertTrue(op_info["meta_info"][0]["sizes"] == [])
946        self.assertTrue(op_info["meta_info"][0]["strides"] == [])
947        # Scalar Tensor
948        self.assertTrue("scalar_value" not in op_info["meta_info"][0])
949        self.assertTrue(op_info["meta_info"][1]["sizes"] == [])
950        self.assertTrue(op_info["meta_info"][1]["strides"] == [])
951        # Scalar Tensor
952        self.assertTrue("scalar_value" not in op_info["meta_info"][1])
953        self.assertTrue(op_info["meta_info"][2]["sizes"] == [])
954        self.assertTrue(op_info["meta_info"][2]["strides"] == [])
955        # Scalar
956        self.assertTrue("scalar_value" in op_info["meta_info"][2])
957
958        with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
959            a = torch.randn(128, device=device)
960            b = torch.randn(128, device=device)
961
962            scalar_values = [1.0, 2.0, 3.0]
963            ref_values = []
964            for scalar_value in scalar_values:
965                ref_values.append(torch.add(a, b, alpha=scalar_value))
966
967            register_ops_with_aoti_compile(
968                namespace_name, [op_name], dispatch_key, torch_compile_op_lib_impl
969            )
970
971            res_values = []
972            for scalar_value in scalar_values:
973                res_values.append(torch.add(a, b, alpha=scalar_value))
974
975            self.assertEqual(len(ref_values), len(res_values))
976            self.assertEqual(ref_values, res_values)
977
978    @skipCUDAIf(not SM80OrLater, "Requires sm80")
979    def test_eager_aoti_override_registration(self):
980        namespace_name = "aten"
981        dispatch_key = "CPU"
982        device = torch.device("cpu")
983        if self.device.lower() == "cuda":
984            dispatch_key = "CUDA"
985            device = torch.device("cuda")
986
987        unary_op_set = ["abs", "acos"]
988
989        def fn(x, op_name=""):
990            return getattr(torch, op_name)(x)
991
992        # Invoke torch.compile directly to get referent results
993        x = torch.randn(3, 4, device=device)
994
995        ref_array = []
996        for unary_op_name in unary_op_set:
997            opt_fn = torch.compile(functools.partial(fn, op_name=unary_op_name))
998            ref = opt_fn(x)
999            ref_array.append(ref)
1000
1001        with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
1002            register_ops_with_aoti_compile(
1003                namespace_name, unary_op_set, dispatch_key, torch_compile_op_lib_impl
1004            )
1005
1006            res_array = []
1007            for unary_op_name in unary_op_set:
1008                res_array.append(getattr(torch, unary_op_name)(x))
1009
1010            for ref, res in zip(ref_array, res_array):
1011                self.assertEqual(ref, res)
1012
1013        a = torch.randn(128, device=device)
1014        min_tensor = torch.randn(128, device=device)
1015        max_tensor = min_tensor + 0.5
1016
1017        ref_with_min = torch.ops.aten.clamp(a, min_tensor)
1018        ref_with_min_max = torch.ops.aten.clamp(a, min_tensor, max_tensor)
1019
1020        with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
1021            register_ops_with_aoti_compile(
1022                namespace_name, ["clamp"], dispatch_key, torch_compile_op_lib_impl
1023            )
1024            res_with_min = torch.ops.aten.clamp(a, min_tensor)
1025            res_with_min_max = torch.ops.aten.clamp(a, min_tensor, max_tensor)
1026            self.assertEqual(ref_with_min, res_with_min)
1027            self.assertEqual(ref_with_min_max, res_with_min_max)
1028
1029    def test_add_const_int(self):
1030        def fn(a):
1031            return (a + 1, torch.add(a, 1, alpha=2))
1032
1033        for dtype in [torch.float32, torch.int32, torch.int64]:
1034            self.common(fn, (torch.arange(32, dtype=dtype),))
1035
1036    def test_add_const_float(self):
1037        def fn(a):
1038            return (a + 1.5,)
1039
1040        self.common(fn, (torch.randn(32),))
1041
1042    def test_add_inplace_permuted(self):
1043        def fn(x, y):
1044            return x.add_(y)
1045
1046        x = torch.ones([2, 12, 13, 17]).transpose(1, 2)
1047        y = torch.randn([2, 13, 1, 17])
1048
1049        self.common(fn, (x, y))
1050
1051    def test_add_complex(self):
1052        def fn(a, b, alpha):
1053            return torch.add(a, b, alpha=alpha)
1054
1055        x = torch.tensor([1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1])
1056        y = torch.tensor([1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1])
1057
1058        self.common(fn, (x, y, 2))
1059
1060    def test_add_complex3(self):
1061        # fix https://github.com/pytorch/pytorch/issues/115071
1062        @torch.compile
1063        def fn(*args):
1064            a = torch.neg(args[0])
1065            b = torch.add(args[0], args[0])
1066            return (a, b)
1067
1068        x = torch.randn(41, dtype=torch.complex64)
1069        y = x.clone()
1070        # should not inplace write to the input
1071        fn(x)
1072        self.assertEqual(x, y)
1073
1074    def test_add_complex4(self):
1075        @torch.compile
1076        def fn(a, b):
1077            c = a + b
1078            d = a + b
1079            return c + d
1080
1081        for dtype in [torch.complex32, torch.complex64, torch.complex128]:
1082            x = torch.tensor(
1083                [1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1],
1084                dtype=dtype,
1085                device=self.device,
1086            )
1087            y = torch.tensor(
1088                [1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1],
1089                dtype=dtype,
1090                device=self.device,
1091            )
1092            _, code = run_and_get_code(fn, x, y)
1093            self.assertEqual(
1094                " ".join(code).count(
1095                    "view_dtype" if config.cpp_wrapper else "aten.view"
1096                ),
1097                3,
1098            )
1099
1100    def test_concat_add_inplace(self):
1101        def fn(x, y, z):
1102            return torch.cat([x, y], dim=1).add_(z)
1103
1104        x = torch.randn([2, 12, 14, 14])
1105        y = torch.randn([2, 12, 14, 14])
1106        z = torch.randn([2, 24, 14, 14])
1107
1108        self.common(fn, (x, y, z))
1109
1110    def test_abs(self):
1111        def fn(a):
1112            return (a / (torch.abs(a) + 1),)
1113
1114        self.common(fn, (torch.randn(17),))
1115
1116    def test_angle(self):
1117        def fn(a, b, c):
1118            return torch.angle(a), torch.angle(b), torch.angle(c)
1119
1120        complex_input = torch.tensor(
1121            [1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1, float("nan")]
1122        )
1123        real_input = torch.tensor([-1.0, 0.0, 1.0, float("nan")])
1124        interger_real_input = torch.tensor([-1, 0, 1])
1125        self.common(fn, (complex_input, real_input, interger_real_input))
1126
1127    def test_sgn(self):
1128        def fn(a):
1129            return torch.sgn(a), torch.sgn(a + 1) - 1
1130
1131        self.common(fn, [torch.linspace(-10, 10, 41)])
1132
1133    @skipCUDAIf(not SM80OrLater, "uses bfloat16 which requires SM >= 80")
1134    def test_scatter_bf16(self):
1135        def fn(inp, src, index):
1136            return inp.scatter_add(0, index, src)
1137
1138        for dtype in [torch.int64, torch.bool, torch.bfloat16]:
1139            self.common(
1140                fn,
1141                [
1142                    torch.zeros(3, 5, dtype=dtype),
1143                    torch.ones((2, 5), dtype=dtype),
1144                    torch.tensor([[0, 1, 2, 0, 0]]),
1145                ],
1146            )
1147
1148    def test_randn_generator(self):
1149        def fn(a, generator):
1150            return torch.randn([20, 20], generator=generator, device=a.device)
1151
1152        self.common(fn, (torch.linspace(-10, 10, 41), None), assert_equal=False)
1153
1154        # generator not yet supported in dynamo
1155        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "Generator"):
1156            self.common(fn, (torch.linspace(-10, 10, 41), torch.Generator(self.device)))
1157
1158    def test_sgn_extremal(self):
1159        def fn(a):
1160            return (torch.sgn(a),)
1161
1162        self.common(fn, [torch.tensor([np.nan, np.inf, -np.inf, 0])])
1163
1164    def test_max_min(self):
1165        def fn(a, b):
1166            return (torch.maximum(a, b), torch.minimum(a, b))
1167
1168        self.common(fn, (torch.randn(8), torch.randn(8)))
1169        t1 = torch.randn(8)
1170        t1[0] = float("nan")
1171        t2 = torch.randn(8)
1172        t2[1] = float("nan")
1173        self.common(fn, (t1, t2))
1174
1175    def test_neg_max_uint8(self):
1176        # https://github.com/pytorch/pytorch/issues/93380
1177        def fn(a, b):
1178            c = torch.neg(a)
1179            return torch.maximum(b, c)
1180
1181        a = torch.randint(256, (1,), dtype=torch.uint8)
1182        b = torch.randint(256, (8390,), dtype=torch.uint8)
1183        self.common(fn, (a, b))
1184
1185    def test_compar(self):
1186        def fn(x):
1187            return x.gt(3.5), x.ge(3.5), x.eq(3.5), x.le(2.5), x.lt(3.5), x.ne(3.5)
1188
1189        a = torch.tensor([3])
1190        self.common(fn, (a,))
1191
1192    def test_horizonal_fusion1(self):
1193        def fn(a, b, c):
1194            return (a + b, a - c, b * c)
1195
1196        self.common(
1197            fn, (torch.randn(8, 16, 16), torch.randn(8, 16, 16), torch.randn(1, 16, 1))
1198        )
1199
1200    def test_horizonal_fusion2(self):
1201        def fn(a, b, c):
1202            return a + 1, b + 2, c + 3
1203
1204        self.common(fn, (torch.randn(8, 16, 8), torch.randn(8, 16), torch.randn(16, 8)))
1205
1206    def test_vertical_fusion1(self):
1207        def fn(sa, ct, p):
1208            # From torchbench.pyhpc_equation_of_state
1209            v17 = -3.087032500374211e-7
1210            v18 = -1.988366587925593e-8
1211            v19 = -1.061519070296458e-11
1212            v20 = 1.550932729220080e-10
1213            t15 = v19 * ct
1214            t19 = v17 + ct * (v18 + t15) + v20 * sa
1215            t20 = 1.0 / t19
1216            t128 = t19 * p
1217            return t20 + t128
1218
1219        self.common(
1220            fn,
1221            (
1222                torch.randn(204, 204, 26),
1223                torch.randn(204, 204, 26),
1224                torch.randn(26),
1225            ),
1226        )
1227        assertGeneratedKernelCountEqual(self, 1)
1228
1229    @config.patch({"fx_graph_cache": False})
1230    def test_forced_buffer_realize(self):
1231        # Test torch._test_inductor_realize forces a buffer to be realized
1232        def fn(a):
1233            b = test_operators.realize(a * 2)
1234            return (b * 2,)
1235
1236        self.common(fn, (torch.randn(10),))
1237        self.assertEqual(torch._inductor.metrics.ir_nodes_pre_fusion, 2)
1238
1239    @config.patch({"fx_graph_cache": False})
1240    def test_scheduler_vertical_fusion1(self):
1241        realize = test_operators.realize
1242
1243        def fn(sa, ct, p):
1244            # From torchbench.pyhpc_equation_of_state
1245            v17 = -3.087032500374211e-7
1246            v18 = -1.988366587925593e-8
1247            v19 = -1.061519070296458e-11
1248            v20 = 1.550932729220080e-10
1249            t15 = realize(v19 * ct)
1250            t19 = realize(v17 + ct * (v18 + t15) + v20 * sa)
1251            t20 = realize(1.0 / t19)
1252            t128 = realize(t19 * p)
1253            return t20 + t128
1254
1255        self.common(
1256            fn,
1257            (
1258                torch.randn(204, 204, 26),
1259                torch.randn(204, 204, 26),
1260                torch.randn(26),
1261            ),
1262        )
1263        self.assertEqual(torch._inductor.metrics.ir_nodes_pre_fusion, 5)
1264        assertGeneratedKernelCountEqual(self, 1 if self.device == GPU_TYPE else 2)
1265
1266    def test_index_propagation(self):
1267        def copy(x):
1268            i = torch.arange(x.size(0), device=x.device)
1269            return x[i]
1270
1271        x = torch.randn(8, device=self.device)
1272        copy_opt = torch._dynamo.optimize("inductor")(copy)
1273
1274        expect = copy(x)
1275        actual = _run_and_assert_no_indirect_indexing(self, copy_opt, x)
1276        self.assertEqual(expect, actual)
1277
1278    def test_index_propagation_flip(self):
1279        def flip(x):
1280            i = torch.arange(x.size(0) - 1, -1, -1, device=x.device)
1281            return x[i]
1282
1283        x = torch.randn(8, device=self.device)
1284        flip_opt = torch._dynamo.optimize("inductor")(flip)
1285
1286        expect = flip(x)
1287        actual = _run_and_assert_no_indirect_indexing(self, flip_opt, x)
1288        self.assertEqual(expect, actual)
1289
1290    def test_index_propagation_floordiv(self):
1291        def repeat_interleave(x, n):
1292            # e.g. x=[1, 2, 3], n=2 => returns [1, 1, 2, 2, 3, 3]
1293            i = torch.arange(x.shape[0] * n, device=x.device)
1294            return x[i // n]
1295
1296        x = torch.randn(8, 16, device=self.device)
1297        repeat_interleave_opt = torch._dynamo.optimize("inductor")(repeat_interleave)
1298        # With static shapes we can prove the bound, our dynamic shapes reasoning is not good enough
1299        has_assert = ifdynstaticdefault(False, True)
1300        # this should be collapsed to direct indexing
1301        actual = _run_and_assert_no_indirect_indexing(
1302            self, repeat_interleave_opt, x, 3, has_assert=has_assert
1303        )
1304        expect = torch.repeat_interleave(x, 3, dim=0)
1305        self.assertEqual(expect, actual)
1306        self.assertEqual(actual, repeat_interleave(x, 3))
1307
1308    def test_index_propagation_remainder(self):
1309        def repeat(x, n):
1310            # e.g. x=[1, 2, 3], n=2 => returns [1, 2, 3, 1, 2, 3]
1311            i = torch.arange(x.shape[0] * n, device=x.device)
1312            return x[i % x.shape[0]]
1313
1314        x = torch.randn(8, 16, device=self.device)
1315        repeat_opt = torch._dynamo.optimize("inductor")(repeat)
1316
1317        # With static shapes we can prove the bound, our dynamic shapes reasoning is not good enough
1318        has_assert = ifdynstaticdefault(False, True)
1319        # this should be collapsed to direct indexing
1320        actual = _run_and_assert_no_indirect_indexing(
1321            self, repeat_opt, x, 3, has_wrapping=False, has_assert=has_assert
1322        )
1323        expect = x.repeat(3, 1)
1324        self.assertEqual(expect, actual)
1325        self.assertEqual(actual, repeat(x, 3))
1326
1327    def test_index_propagation_abs(self):
1328        def reflection_pad_left(x, n):
1329            # e.g. x=[1, 2, 3], n=2 => returns [3, 2, 1, 2, 3]
1330            i = torch.arange(x.shape[0] + n, device=x.device)
1331            return x[(i - n).abs()]
1332
1333        x = torch.randn(8, device=self.device)
1334        opt_fn = torch._dynamo.optimize("inductor")(reflection_pad_left)
1335
1336        # With static shapes we can prove the bound, our dynamic shapes reasoning is not good enough
1337        has_assert = ifdynstaticdefault(False, True)
1338        # this should be collapsed to direct indexing
1339        actual = _run_and_assert_no_indirect_indexing(
1340            self, opt_fn, x, 3, has_wrapping=False, has_assert=has_assert
1341        )
1342        expect = reflection_pad_left(x, 3)
1343        self.assertEqual(expect, actual)
1344
1345    def test_index_propagation_device_assert_masked(self):
1346        def fn(a):
1347            idx = torch.arange(a.size(0), device=a.device)
1348            padded_idx = torch.constant_pad_nd(idx, (1050, 0))
1349            padded_idx = torch.where(padded_idx >= 0, padded_idx, padded_idx)
1350            return a[padded_idx]
1351
1352        self.common(fn, (torch.randn(1024),))
1353
1354    @skipIfRocm
1355    @config.patch(debug_index_asserts=False)
1356    def test_neg_index(self):
1357        def test(
1358            fn, inps, has_assert: bool, has_wrapping: bool, vectorize: bool = True
1359        ):
1360            fn_opt = torch.compile(fn)
1361            if self.device == "cpu":
1362                _, code = run_and_get_cpp_code(fn_opt, *inps)
1363                self.assertTrue(("?" in code or "blendv" in code) is has_wrapping)
1364                self.assertTrue(("TORCH_CHECK" in code) is has_assert)
1365                # Assert that we always vectorize the kernel regardless of wrapping / checks
1366                self.assertTrue(("loadu" in code) is vectorize)
1367            else:
1368                code = run_and_get_triton_code(fn_opt, *inps)
1369                self.assertTrue(("tl.where" in code) is has_wrapping)
1370                self.assertTrue(("device_assert" in code) is has_assert)
1371
1372        def indirect(a, b):
1373            return a[b - 1]
1374
1375        a = torch.rand(1024, device=self.device)
1376        b = torch.zeros(256, dtype=torch.long, device=self.device)
1377        test(indirect, (a, b), has_assert=True, has_wrapping=True)
1378
1379        def direct(x):
1380            return x[:, -1]
1381
1382        a = torch.rand(1, 64, 32, device=self.device)
1383        # Does not even generate a kernel as it's a view
1384        test(direct, (a,), has_assert=False, has_wrapping=False, vectorize=False)
1385
1386        def flip(a, b):
1387            return a[b]
1388
1389        a = torch.rand(1024, device=self.device)
1390        b = torch.arange(start=-1, end=-a.numel() - 1, step=-1, device=self.device)
1391        test(flip, (a, b), has_assert=True, has_wrapping=True)
1392
1393        # Constant propagate a constant that's negative
1394        def flip_with_index_constant(a):
1395            b = torch.arange(start=-1, end=-a.numel() - 1, step=-1, device=a.device)
1396            return a[b]
1397
1398        # Wrapping is constant-folded
1399        test(flip_with_index_constant, (a,), has_assert=False, has_wrapping=False)
1400
1401        # Operation where we can't prove that the index is always positive or negative
1402        def pos_and_neg(a):
1403            b = torch.arange(start=1, end=-a.numel() - 1, step=-1, device=a.device)
1404            return a[b]
1405
1406        # It has wrapping but no assert
1407        test(pos_and_neg, (a,), has_assert=False, has_wrapping=True)
1408
1409        # We currently don't do constant propagation with float constants
1410        # We cannot prove this kind of asserts just with bounds. We would need
1411        # to lift IndexPropagation.shape_env to be accessible in all of Inductor
1412        def flip_with_index(a):
1413            b = 1.0 * torch.arange(
1414                start=-1, end=-a.numel() - 1, step=-1, device=a.device
1415            )
1416            b = b.int()
1417            return a[b]
1418
1419        test(
1420            flip_with_index,
1421            (a,),
1422            has_assert=ifdynstaticdefault(False, True),
1423            has_wrapping=False,
1424            vectorize=False,  # Constant propagation off -> indirect indexing -> no vec
1425        )
1426
1427        def unsafe_index(a, b):
1428            return aten._unsafe_index(a, (b,))
1429
1430        test(unsafe_index, (a, b), has_assert=False, has_wrapping=True)
1431
1432        def constant_propagation(a):
1433            b = torch.tensor([2], device=a.device)
1434            return a[b]
1435
1436        test(
1437            constant_propagation,
1438            (a,),
1439            has_assert=ifdynstaticdefault(False, True),
1440            has_wrapping=False,
1441            vectorize=False,  # There's no loop to vectorize!
1442        )
1443
1444        def constant_propagation_neg(a):
1445            b = torch.tensor([-2], device=a.device)
1446            return a[b]
1447
1448        # In symbolic shapes, we know that we can access -2, so no assert is necessary!
1449        test(
1450            constant_propagation_neg,
1451            (a,),
1452            has_assert=False,
1453            has_wrapping=False,
1454            vectorize=False,  # There's no loop to vectorize!
1455        )
1456
1457    def test_computed_buffer_inlining(self):
1458        def flip(x):
1459            idx = torch.arange(x.size(0) - 1, -1, -1, device=x.device)
1460            return x[idx], idx
1461
1462        flip_opt = torch._dynamo.optimize("inductor")(flip)
1463        x = torch.randn(8, device=self.device)
1464
1465        expect = flip(x)
1466        actual = _run_and_assert_no_indirect_indexing(self, flip_opt, x)
1467        self.assertEqual(expect, actual)
1468
1469    def test_sum1(self):
1470        def fn(a, b):
1471            return ((a + b).sum(-1),)
1472
1473        self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
1474
1475    def test_sum2(self):
1476        def fn(a, b):
1477            return ((a + b).sum([1, 2]), (a + b).sum(-1))
1478
1479        self.common(fn, (torch.randn(8, 9, 3, 21), torch.randn(8, 9, 3, 21)))
1480
1481    def test_sum3(self):
1482        def fn(a, b):
1483            r1 = a + b
1484            r2 = r1.sum(-1)
1485            r3 = torch.squeeze(b) + 10
1486            return (r1, r2, r3)
1487
1488        # Mismatched elements: 2 / 10 (20.0%)
1489        # Greatest absolute difference: 0.0029296875 at index (8,) (up to 1e-05 allowed)
1490        # Greatest relative difference: 0.0017482517482517483 at index (6,) (up to 0.001 allowed)
1491        self.common(fn, (torch.randn(10, 10), torch.randn(1, 10)), atol=1e-5, rtol=2e-3)
1492
1493    def test_sum4(self):
1494        def fn(a):
1495            b = a + 1
1496            c = b.sum(-1)
1497            d = c + 3
1498            e = d.sum(-1)
1499            f = e + 5
1500            return (f, e, d, c, b)
1501
1502        self.common(fn, (torch.randn(1, 16, 8, 8),))
1503
1504    def test_sum5(self):
1505        def fn(a):
1506            b = a + 1
1507            c = b.sum(-1)
1508            d = c + 3
1509            e = d.sum(-1)
1510            f = e + 5
1511            return (f,)
1512
1513        self.common(fn, (torch.randn(1, 17, 8, 9),))
1514
1515    def test_reduction1(self):
1516        def fn(a):
1517            return (a.sum(), a.max(), a.min(), a.argmax(), a.argmin())
1518
1519        self.common(fn, (torch.tensor([float("-inf"), 0.0, float("inf")]),))
1520
1521    @skip_if_x86_mac()
1522    def test_reduction2(self):
1523        def fn(a):
1524            # FIXME: a.argmax
1525            return (a.sum(), a.max(), a.min(), a.argmin())
1526
1527        self.common(fn, (torch.full((4,), float("inf")),))
1528
1529    @skip_if_x86_mac()
1530    def test_reduction3(self):
1531        def fn(a):
1532            # FIXME: a.argmin
1533            return (a.sum(), a.max(), a.min(), a.argmax())
1534
1535        self.common(fn, (torch.full((4,), float("-inf")),))
1536
1537    def test_reduction4(self):
1538        if self.device == "cpu":
1539            raise unittest.SkipTest("Non-deterministic CPU results")
1540
1541        def fn(a):
1542            return (a.argmax(-1), a.argmin(-1))
1543
1544        inputs = (torch.ones(128), torch.ones(4, 4, 1))
1545        for i in inputs:
1546            self.common(fn, (i,))
1547
1548    @config.patch(unroll_reductions_threshold=1)
1549    def test_reduction5(self):
1550        if self.device == "cpu":
1551            raise unittest.SkipTest("Non-deterministic CPU results")
1552
1553        def fn(a):
1554            return (a.sum(), a.max(), a.min(), a.argmax())
1555
1556        self.common(fn, (torch.full((4,), float("-inf")),))
1557
1558    def test_prod(self):
1559        def fn(a):
1560            return a.prod(0), a.prod(1), a.prod()
1561
1562        self.common(fn, (torch.rand((10, 10)),))
1563        self.common(fn, (torch.rand((1, 2050)),))
1564
1565    def test_unroll_small_reduction(self):
1566        def fn(x):
1567            val1, index1 = x.min(-1)
1568            val2, index2 = x.max(-1)
1569            return (
1570                val1,
1571                index1,
1572                val2,
1573                index2,
1574                x.sum(-1),
1575                (x > 1).any(-1),
1576                (x > 0).all(-1),
1577                x.argmin(-1),
1578                x.argmax(-1),
1579                x.amin(-1),
1580                x.amax(-1),
1581                x.aminmax(),
1582            )
1583
1584        with config.patch(unroll_reductions_threshold=8):
1585            # small sized reductions will get unrolled
1586            self.common(fn, (torch.randn(8, 3),))
1587        torch._dynamo.reset()
1588        with config.patch(unroll_reductions_threshold=1):
1589            # make sure things also work if they aren't unrolled
1590            self.common(fn, (torch.randn(8, 3),))
1591
1592    def test_multilayer_sum_low_prec(self):
1593        # fp16 nyi for cpu
1594        if self.device == "cpu":
1595            raise unittest.SkipTest(f"requires {GPU_TYPE}")
1596
1597        def fn(a):
1598            return torch.mean(a)
1599
1600        self.common(fn, ((torch.rand((10, 3, 352, 352), dtype=torch.float16),)))
1601
1602    def test_multilayer_prime_size(self):
1603        def fn(a):
1604            return torch.max(a), torch.sum(a)
1605
1606        # Requires masked loading for the intermediate reduction
1607        sample = torch.full((3999971,), 0, dtype=torch.int64)
1608        sample[-1] = 1
1609        self.common(fn, (sample,))
1610
1611    @skipCPUIf(IS_MACOS, "fails on macos")
1612    def test_multilayer_var(self):
1613        def fn(a):
1614            return torch.var(a)
1615
1616        self.common(fn, ((torch.rand((10, 3, 352, 352), dtype=torch.float32),)))
1617        self.common(fn, ((torch.rand((14923), dtype=torch.float32),)))
1618
1619    @skipCPUIf(IS_MACOS, "fails on macos")
1620    def test_multilayer_var_lowp(self):
1621        def fn(a):
1622            return torch.var(a)
1623
1624        self.common(fn, (torch.rand((16, 16, 352, 352), dtype=torch.float16),))
1625        self.common(fn, (torch.rand((14923), dtype=torch.float16),))
1626
1627    def test_split_cumsum(self):
1628        def fn(a):
1629            return torch.cumsum(a, -1)
1630
1631        for dtype in get_all_dtypes(
1632            include_bfloat16=False,
1633            include_bool=True,
1634            include_complex=False,
1635            include_half=False,
1636        ):
1637            # Use low=0 since when the mean value is 0, cumsum at all points
1638            # tends towards zero which makes the relative error term blow up
1639            inp = make_tensor(10, 3, 352, 352, low=0, dtype=dtype, device=self.device)
1640            self.common(fn, (inp.view(-1),), rtol=1e-5, atol=1e-5, check_lowp=False)
1641            self.common(fn, (inp.view(10, -1),), rtol=1e-5, atol=1e-5, check_lowp=False)
1642
1643    @skipCUDAIf(not SM80OrLater, "Requires sm80")
1644    @skipCUDAIf(TEST_WITH_ROCM, "Computation not done in float on ROCm")
1645    def test_split_cumsum_low_prec(self):
1646        if self.device == "cpu":
1647            raise unittest.SkipTest("ir.Scan nyi on CPU")
1648
1649        def fn(a):
1650            return torch.cumsum(a.view(-1), 0)
1651
1652        self.common(
1653            fn,
1654            (torch.rand((10, 3, 352, 352), dtype=torch.float16),),
1655            reference_in_float=True,
1656            check_lowp=False,
1657        )
1658
1659    def test_consecutive_split_cumsum(self):
1660        def fn(a, b):
1661            a = a.view(-1)
1662            b = b.view(-1)
1663            return torch.cumsum(a, 0) + torch.cumsum(b, 0)
1664
1665        a = make_tensor(10, 3, 352, 352, low=0, dtype=torch.float32, device=self.device)
1666        b = make_tensor(10, 3, 352, 352, low=0, dtype=torch.float64, device=self.device)
1667        self.common(fn, (a, b), rtol=1e-5, atol=1e-5, check_lowp=False)
1668
1669    def test_split_cumprod(self):
1670        def fn(a):
1671            return torch.cumprod(a, -1)
1672
1673        for dtype in [torch.float32, torch.float64, torch.int32, torch.int64]:
1674            inp = _large_cumprod_input(
1675                (10, 10000), dim=1, dtype=dtype, device=self.device
1676            )
1677            self.common(fn, (inp,), atol=1e-5, rtol=1e-4, check_lowp=False)
1678
1679    @skipCUDAIf(not SM80OrLater, "Requires sm80")
1680    @skipCUDAIf(TEST_WITH_ROCM, "Computation not done in float on ROCm")
1681    def test_split_cumprod_low_prec(self):
1682        if self.device == "cpu":
1683            raise unittest.SkipTest("ir.Scan nyi on CPU")
1684
1685        def fn(a):
1686            return torch.cumprod(a.view(-1), 0)
1687
1688        for dtype in [torch.float16, torch.bfloat16]:
1689            inp = _large_cumprod_input(
1690                (10, 10000), dim=1, dtype=dtype, device=self.device
1691            )
1692            self.common(
1693                fn,
1694                (inp,),
1695                reference_in_float=True,
1696                check_lowp=False,
1697            )
1698
1699    def test_consecutive_split_cumprod(self):
1700        def fn(a, b):
1701            return torch.cumprod(a, 0) + torch.cumprod(b, 0)
1702
1703        a = _large_cumprod_input(
1704            (10000,), dim=0, dtype=torch.float32, device=self.device
1705        )
1706        b = _large_cumprod_input(
1707            (10000,), dim=0, dtype=torch.float64, device=self.device
1708        )
1709        self.common(fn, (a, b), atol=1e-5, rtol=1e-5, check_lowp=False)
1710
1711    @skipCUDAIf(TEST_WITH_ROCM, "associative_scan is not supported on ROCm")
1712    def test_custom_scan_op(self):
1713        if self.device != "cuda":
1714            raise unittest.SkipTest("associative_scan only supported on GPU")
1715
1716        def sum_combine(a, b):
1717            return a + b
1718
1719        from torch._higher_order_ops.associative_scan import associative_scan
1720
1721        a = torch.randn(100, 100, device=self.device)
1722        expect = torch.cumsum(a, 0)
1723        actual = associative_scan(sum_combine, a, 0)
1724        self.assertEqual(expect, actual)
1725
1726        def logcumsum_combine(a, b):
1727            min_v = torch.minimum(a, b)
1728            max_v = torch.maximum(a, b)
1729            mask = (min_v != max_v) | ~min_v.isinf()
1730            return torch.where(mask, max_v + (min_v - max_v).exp().log1p(), a)
1731
1732        expect = torch.logcumsumexp(a, 0)
1733        actual = associative_scan(logcumsum_combine, a, 0)
1734        self.assertEqual(expect, actual)
1735
1736    def test_custom_scan_op_compiled(self):
1737        if self.device != "cuda":
1738            raise unittest.SkipTest("associative_scan only supported on GPU")
1739
1740        from torch._higher_order_ops.associative_scan import associative_scan
1741
1742        def sum_combine(a, b):
1743            return a + b
1744
1745        def fn(a, b, dim):
1746            diff = (a - b).abs()
1747            sad = associative_scan(sum_combine, diff, dim)
1748            return sad.sum(dim)
1749
1750        a = torch.randn(100, 100, device=self.device)
1751        b = torch.randn(100, 100, device=self.device)
1752        self.common(fn, (a, b, 0))
1753        cfn = torch.compile(fn)
1754        _, code = run_and_get_code(cfn, a, b, 0)
1755
1756        # Check everything is fused into a single kernel
1757        FileCheck().check_not("run(").check_regex(
1758            r"triton_.*\.run\(arg[01]_1, arg[12]_1, buf1,"
1759        ).check_not("run(").run(code[0])
1760
1761    @skipCUDAIf(TEST_WITH_ROCM, "associative_scan is not supported on ROCm")
1762    def test_custom_scan_op_multi_input(self):
1763        if self.device != "cuda":
1764            raise unittest.SkipTest("associative_scan only supported on GPU")
1765
1766        def argmax_combine(a, b):
1767            a_value, a_index = a
1768            b_value, b_index = b
1769            mask = (a_value > b_value) | ((a_value == b_value) & (a_index > b_index))
1770            return (
1771                torch.where(mask, a_value, b_value),
1772                torch.where(mask, a_index, b_index),
1773            )
1774
1775        from torch._higher_order_ops.associative_scan import associative_scan
1776
1777        a = torch.randn(100, 100, device=self.device)
1778        expect = torch.cummax(a, 0)
1779
1780        idx = torch.arange(100, device=self.device).view(100, 1).expand(100, 100)
1781        actual = associative_scan(argmax_combine, (a, idx), 0)
1782        self.assertEqual(expect, actual)
1783
1784    def test_embedding_bag_byte_unpack(self):
1785        if self.device != "cpu":
1786            raise unittest.SkipTest(f"No {GPU_TYPE} implementation (it returns empty)")
1787
1788        def fn(a):
1789            return torch.ops.quantized.embedding_bag_byte_unpack(a)
1790
1791        M, N = 32, 64
1792        scales = torch.randn(M, 1).view(torch.uint8)
1793        offsets = torch.randn(M, 1).view(torch.uint8)
1794        data = torch.randint(0, 255, (M, N), dtype=torch.uint8)
1795        packed = torch.cat([data, scales, offsets], dim=-1)
1796        self.common(fn, [packed])
1797
1798    def test_expanded_reduction(self):
1799        def fn(x, y):
1800            z = x * y
1801            return z.sum((0, 1))
1802
1803        atol = None
1804        rtol = None
1805
1806        # By default, inductor generate non-persistent reduction kernels in this
1807        # case. But when multi-kernel is enabled, inductor will pick the faster
1808        # of persistent reduction and non-persistent-reduction kernel.
1809        # In this case, inductor picked the persistent-reduction kernel.
1810        # The persistent reduction kernel happens to need looser tolerance.
1811        if config.triton.multi_kernel:
1812            atol = 1e-5
1813            rtol = 1e-5
1814        self.common(
1815            fn, (torch.randn(2, 197, 256), torch.randn(2, 1, 256)), atol=atol, rtol=rtol
1816        )
1817
1818    def test_min_max_reduction(self):
1819        def fn(a, b):
1820            return (
1821                (a + b).max(),
1822                (a + b).min(),
1823                torch.amax(a + 1, keepdim=True),
1824                torch.amin(b + 1, keepdim=True),
1825            )
1826
1827        dtypes = [torch.float, torch.float16]
1828        if not (self.device == "cuda" and not SM80OrLater):
1829            dtypes += [torch.bfloat16]
1830        for dtype in dtypes:
1831            self.common(fn, (torch.randn(8, 8).to(dtype), torch.randn(8, 8).to(dtype)))
1832
1833    def test_min_max_reduction_nan(self):
1834        def fn(a):
1835            return (torch.max(a), torch.min(a))
1836
1837        t1 = torch.randn(32)
1838        t1[16] = float("nan")
1839        self.common(fn, (t1,))
1840
1841    def test_fmin_fmax(self):
1842        def fn(a, b):
1843            return (
1844                torch.fmin(a, b),
1845                torch.fmax(a, b),
1846                torch.fmax(a + 1, torch.tensor(0.0)),
1847            )
1848
1849        self.common(
1850            fn,
1851            (
1852                torch.tensor(
1853                    [-10.0, 10.0, float("nan"), float("nan"), float("nan"), 3, 4]
1854                ),
1855                torch.tensor(
1856                    [float("nan"), float("nan"), -10.0, 10.0, float("nan"), 4, 3]
1857                ),
1858            ),
1859        )
1860
1861    def test_sum_int(self):
1862        def fn(x):
1863            return 2 * x.sum(-1) + x.sum()
1864
1865        dtypes = torch.bool, torch.uint8, torch.int
1866        inps = [torch.randint(2, (64,), dtype=dtype) for dtype in dtypes]
1867        for i in inps:
1868            self.common(fn, (i,), check_lowp=False)
1869
1870    def test_sum_dtype(self):
1871        def fn(x):
1872            return x * x.sum(-1, dtype=torch.double) + x.sum(dtype=torch.double)
1873
1874        self.common(fn, (torch.ones(32, 32) * 70,))
1875
1876    def test_cumsum(self):
1877        def fn(x):
1878            return x.cumsum(0), x.cumsum(1)
1879
1880        # Persistent reductions
1881        self.common(fn, (torch.rand(16, 32),), check_lowp=True)
1882        self.common(fn, (torch.rand(20, 30),), check_lowp=True)
1883
1884        # Non-persistent reduction
1885        self.common(fn, (torch.rand(100, 4000),), check_lowp=True)
1886
1887    def test_cumsum_zero_dim(self):
1888        def fn(x):
1889            return x.cumsum(0), x.cumsum(-1)
1890
1891        a = torch.rand(())
1892        self.common(fn, (a,))
1893
1894    def test_cumsum_no_mask(self):
1895        def fn(x):
1896            return x.cumsum(-1)
1897
1898        # Persistent reduction
1899        a = torch.rand((1, 1024))
1900        self.common(fn, (a,), check_lowp=not TEST_WITH_ROCM)
1901
1902        # Non-persistent reduction
1903        b = torch.rand((1, 8192))
1904        self.common(fn, (b,), check_lowp=not TEST_WITH_ROCM)
1905
1906    def test_cumprod_zero_dim(self):
1907        def fn(x):
1908            return x.cumprod(0), x.cumprod(-1)
1909
1910        a = torch.rand(())
1911        self.common(fn, (a,))
1912
1913    def test_logcumsumexp(self):
1914        def fn(x):
1915            return x.logcumsumexp(0), x.logcumsumexp(1)
1916
1917        # Persistent reductions
1918        self.common(fn, (torch.rand(16, 32),), check_lowp=not TEST_WITH_ROCM)
1919        self.common(fn, (torch.rand(20, 30),), check_lowp=not TEST_WITH_ROCM)
1920
1921        # Non-persistent reduction
1922        self.common(fn, (torch.rand(100, 4000),), check_lowp=not TEST_WITH_ROCM)
1923
1924    def test_logcumsumexp_zero_dim(self):
1925        def fn(x):
1926            return x.logcumsumexp(0), x.logcumsumexp(-1)
1927
1928        a = torch.rand(())
1929        self.common(fn, (a,))
1930
1931    def test_clamp(self):
1932        def fn(a, b):
1933            return (a.clamp(-0.1, 0.1), b.clamp(0), torch.clamp(a + b, max=0))
1934
1935        self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
1936
1937    def test_clamp_type_promotion(self):
1938        def fn(a):
1939            b = torch.tensor(1.0, dtype=torch.double, device=self.device)
1940            c = torch.full((4,), 2, device=self.device)
1941            return a.clamp(min=b, max=c)
1942
1943        self.common(fn, (torch.randint(4, (4,)),))
1944
1945    def test_dist(self):
1946        def fn(a, b):
1947            return (
1948                torch.dist(a, b),
1949                torch.dist(a, b, p=1.2),
1950            )
1951
1952        self.common(fn, (torch.randn(4, 4), torch.randn(4, 4)))
1953
1954    @skipCUDAIf(not SM80OrLater, "Requires sm80")
1955    def test_dist_bf16(self):
1956        def fn(a, b):
1957            return torch.dist(a.to(torch.bfloat16), b.to(torch.bfloat16))
1958
1959        self.common(fn, (torch.randn(4, 4), torch.randn(4, 4)))
1960
1961    def test_arange1(self):
1962        def fn(x):
1963            rng1 = torch.arange(8 * 8, dtype=torch.float32, device=x.device).view(8, 8)
1964            rng2 = torch.arange(10, 18, device=x.device)
1965            tmp = x * rng1
1966            return tmp, tmp + rng2
1967
1968        self.common(fn, (torch.randn(8, 8),))
1969
1970    def test_arange2(self):
1971        def fn(x):
1972            rng1 = torch.arange(8, device=x.device)
1973            return (x + rng1,)
1974
1975        self.common(fn, (torch.randint(4, (8, 8)),), check_lowp=False)
1976
1977    def test_arange3(self):
1978        def fn(x):
1979            return x + torch.ops.aten.arange.start_step(
1980                0, 53, 4, dtype=torch.int64, device=x.device
1981            )
1982
1983        self.common(fn, (torch.randn(14),))
1984
1985    def test_arange4(self):
1986        def fn(x):
1987            return x - torch.arange(512, -512, -1.0, device=x.device)
1988
1989        self.common(fn, (torch.randn(1024),))
1990
1991    def test_arange5(self):
1992        def fn(step, device):
1993            return torch.arange(512, -512, step, device=device)
1994
1995        compiled_fn = torch._dynamo.optimize()(fn)
1996
1997        # NOTE: use assertEqual to check dtypes which self.common doesn't do
1998        for step in (-1, -1.0):
1999            expect = fn(step, self.device)
2000            actual = compiled_fn(step, self.device)
2001            self.assertEqual(expect, actual)
2002        self.assertEqual(expect, actual)
2003
2004    def test_arange6(self):
2005        def fn(x):
2006            return torch.arange(0.1, 8.0001, 1, dtype=x.dtype, device=x.device)
2007
2008        # Test that float arguments are truncated to int when dtype is set explicitly
2009        make_arg = functools.partial(
2010            make_tensor, device=self.device, requires_grad=False
2011        )
2012        self.common(fn, (make_arg(1, dtype=torch.float32),))
2013        self.common(fn, (make_arg(1, dtype=torch.int64),))
2014
2015    def test_linspace1(self):
2016        def fn(x):
2017            return torch.linspace(0.125, 0.875, 7, device=x.device) + x
2018
2019        self.common(fn, (torch.randn(1, 7),))
2020
2021    def test_linspace2(self):
2022        def fn(x):
2023            return torch.linspace(0, 2, 1, device=x.device) + x
2024
2025        self.common(fn, (torch.randn(1, 1),))
2026
2027    def test_linspace3(self):
2028        def fn(x):
2029            return torch.linspace(0, 2, 0, device=x.device)
2030
2031        self.common(fn, (torch.Tensor([]),))
2032
2033    def test_tensor1(self):
2034        def fn(x):
2035            return torch.tensor([1], device=x.device) + x, torch.tensor(
2036                5, device=x.device
2037            )
2038
2039        self.common(fn, (torch.randn(10),))
2040
2041    def test_tensor2(self):
2042        def fn(x):
2043            return torch.tensor(list(range(2, 40, 2)), device=x.device) + x
2044
2045        self.common(fn, (torch.randn(1),))
2046
2047    def test_tensor3(self):
2048        def fn(x):
2049            return (
2050                torch.tensor([], device=x.device),
2051                torch.tensor([1, 2], device=x.device) + 1,
2052                torch.tensor([1, 2, 3], device=x.device) + 2,
2053                torch.tensor([1, 2, 3, 4], device=x.device) + x,
2054            )
2055
2056        self.common(fn, [torch.randn(4)])
2057
2058    def test_views1(self):
2059        def fn1(x, y):
2060            return (x.view(size2) + y,)
2061
2062        def fn2(x, y):
2063            return ((x + 1).view(size2) + y,)
2064
2065        views = [
2066            ([5 * 7], [5, 7]),
2067            ([2 * 3 * 4 * 5 * 6 * 7], [2, 3, 4, 5, 6, 7]),
2068            ([2 * 3, 4, 5, 6 * 7], [2, 3, 4, 5, 6, 7]),
2069            ([10 * 5, 20], [10, 5, 20]),
2070            ([1, 10, 1], [10]),
2071            ([10, 1, 10, 1, 10], [10, 100]),
2072            ([2, 2, 2, 2], [4, 4]),
2073        ]
2074        for size1, size2 in views:
2075            self.common(fn1, (torch.randn(size1), torch.randn(size2)))
2076            self.common(fn2, (torch.randn(size1), torch.randn(size2)))
2077
2078        for size2, size1 in views:
2079            self.common(fn1, (torch.randn(size1), torch.randn(size2)))
2080            self.common(fn2, (torch.randn(size1), torch.randn(size2)))
2081
2082    def test_views2(self):
2083        def fn1(x):
2084            return (x.view(size2) + 1,)
2085
2086        def fn2(x):
2087            return ((x * 2).view(size2) + 1,)
2088
2089        for size1, size2 in [
2090            ([2, 2, 2, 2], [4, -1]),
2091            ([10, 1, 10, 1, 10], [-1, 100]),
2092            ([10 * 5, 20], [10, -1, 20]),
2093        ]:
2094            self.common(fn1, (torch.randn(size1),))
2095            self.common(fn2, (torch.randn(size1),))
2096
2097    def test_views3(self):
2098        # example taken from hf_BigBird
2099        def forward(arg1, arg2):
2100            index = torch.ops.aten.index(arg1, [arg2])
2101            view_1 = torch.ops.aten.view(index, [1, 2232, 64])
2102            view_2 = torch.ops.aten.view(view_1, [1, 12, 62, 192])
2103            return view_2
2104
2105        self.common(
2106            forward,
2107            (
2108                rand_strided((64, 64), (64, 1), torch.float32),
2109                rand_strided((2232,), (1,), torch.int64),
2110            ),
2111        )
2112
2113    def test_views4(self):
2114        # example taken from hf_BigBird
2115        def forward(arg1, arg2):
2116            arg1 = arg1.index_select(0, arg2)
2117            arg1 = torch.ops.aten.view(arg1, [2, 3, 4, 5, 5])
2118            arg1 = torch.ops.aten.view(arg1, [2, 3, 2, 10, -1])
2119            return arg1
2120
2121        self.common(
2122            forward,
2123            (
2124                torch.randn(12, 5, 5),
2125                torch.randint(0, 11, (24,)),
2126            ),
2127        )
2128
2129    def test_views5(self):
2130        # tensor with shape 0 in any dimension
2131        def forward(x):
2132            y = x[:, 4:]
2133            return y.view(len(y), -1, 4)
2134
2135        self.common(
2136            forward,
2137            (torch.randn(4, 4, 4, 4),),
2138        )
2139
2140    def test_views6(self):
2141        def forward(x):
2142            x = torch.ops.aten.relu(x)
2143            s = torch.ops.aten.slice(x, 0, 0, 9223372036854775807)
2144            s = torch.ops.aten.slice(s, 1, 0, 9223372036854775807)
2145            s = torch.ops.aten.slice(s, 3, 0, 0)
2146            y = torch.ops.aten.view(s, [4, 2, -1])
2147            return y
2148
2149        self.common(
2150            forward,
2151            (torch.randn(4, 2, 4, 4),),
2152        )
2153
2154    def test_views7(self):
2155        # x.view(dtype)
2156        def forward(x, y):
2157            x = (x + 1).to(torch.float32)
2158            y = (y + 1).to(torch.int32)
2159            return x.view(torch.int32), y.view(torch.float32)
2160
2161        self.common(
2162            forward,
2163            (
2164                torch.rand(2, 3, dtype=torch.float32),
2165                torch.randint(10, (2, 3), dtype=torch.int32),
2166            ),
2167        )
2168
2169    def test_relu(self):
2170        def fn(a, b):
2171            return (torch.relu(a), torch.relu(a + b) / 10)
2172
2173        self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
2174
2175    def test_exp(self):
2176        def fn(a, b):
2177            return (torch.exp(a), torch.exp(a + b))
2178
2179        self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
2180
2181    def test_exp2(self):
2182        def fn(a, b):
2183            return (torch.exp2(a), torch.exp2(a + b), torch.pow(2, -torch.abs(a - b)))
2184
2185        self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
2186
2187    def test_sigmoid(self):
2188        def fn(a, b):
2189            return (torch.sigmoid(a), torch.sigmoid(a + b))
2190
2191        self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
2192
2193    def test_round(self):
2194        def fn(a, b):
2195            return torch.round(a), torch.round(b + 1), torch.round(a, decimals=2)
2196
2197        # without manual_seed, there is some chance this test fails due to:
2198        # https://github.com/openai/triton/issues/530
2199        torch.manual_seed(0)
2200
2201        # with *100 we are always getting a number exactly at .5 which we don't do right in half
2202        self.common(fn, (torch.randn(8, 8) * 100, torch.randn(8, 8) * 10))
2203
2204    def test_round_correctness(self):
2205        if self.device == "cuda":
2206            raise unittest.SkipTest("need to debug tl.libdevice on A100/V100")
2207
2208        def fn(a):
2209            return torch.round(a)
2210
2211        self.common(
2212            fn,
2213            [torch.arange(-10, 10, 0.1, dtype=torch.float64)],
2214            check_lowp=False,
2215        )
2216
2217    def test_builtins_round(self):
2218        def fn(x, i):
2219            return x[: round(i / 2 + 1)] + round(i / 2)
2220
2221        cfn = torch.compile(fullgraph=True, dynamic=True)(fn)
2222
2223        x = torch.zeros(5, dtype=torch.int, device=self.device)
2224        with torch.no_grad():
2225            for i in range(1, 6):
2226                self.assertEqual(cfn(x, i), fn(x, i))
2227
2228    def test_builtins_round_float_ndigits_pos(self):
2229        def fn(x, i):
2230            return x + round(i / 2 * 123.4567, 1)
2231
2232        cfn = torch.compile(fullgraph=True, dynamic=True)(fn)
2233
2234        x = torch.zeros(2, device=self.device)
2235        i = 2
2236
2237        with torch.no_grad():
2238            self.assertEqual(cfn(x, i), fn(x, i))
2239
2240    def test_builtins_round_float_ndigits_zero(self):
2241        def fn(x, i):
2242            return x + round(i / 2 * 123.4567, 0)
2243
2244        cfn = torch.compile(fullgraph=True, dynamic=True)(fn)
2245
2246        x = torch.zeros(2, device=self.device)
2247        i = 2
2248
2249        with torch.no_grad():
2250            self.assertEqual(cfn(x, i), fn(x, i))
2251
2252    def test_builtins_round_float_ndigits_neg(self):
2253        def fn(x, i):
2254            return x + round(i / 2 * 123.4567, -1)
2255
2256        cfn = torch.compile(fullgraph=True, dynamic=True)(fn)
2257
2258        x = torch.zeros(2, device=self.device)
2259        i = 2
2260
2261        with torch.no_grad():
2262            self.assertEqual(cfn(x, i), fn(x, i))
2263
2264    def test_builtins_round_int_ndigits_pos(self):
2265        def fn(x, i):
2266            return x + round(i, 1)
2267
2268        cfn = torch.compile(fullgraph=True, dynamic=True)(fn)
2269
2270        x = torch.zeros(2, device=self.device)
2271        i = 123
2272
2273        with torch.no_grad():
2274            self.assertEqual(cfn(x, i), fn(x, i))
2275
2276    def test_builtins_round_int_ndigits_zero(self):
2277        def fn(x, i):
2278            return x + round(i, 0)
2279
2280        cfn = torch.compile(fullgraph=True, dynamic=True)(fn)
2281
2282        x = torch.zeros(2, device=self.device)
2283        i = 123
2284
2285        with torch.no_grad():
2286            self.assertEqual(cfn(x, i), fn(x, i))
2287
2288    def test_silu(self):
2289        def fn(a):
2290            return (torch.nn.functional.silu(a),)
2291
2292        self.common(fn, (torch.randn(8, 8),))
2293
2294    def test_nan_to_num(self):
2295        def fn(a):
2296            return (
2297                torch.nan_to_num(a),
2298                torch.nan_to_num(a, nan=3.0),
2299                torch.nan_to_num(a, nan=None),
2300                torch.nan_to_num(a, posinf=4.0),
2301                torch.nan_to_num(a, neginf=5.0),
2302                torch.nan_to_num(a, nan=3.0, posinf=4.0, neginf=5.0),
2303            )
2304
2305        self.common(
2306            fn,
2307            (torch.tensor((float("nan"), float("inf"), float("-inf"), 1.0)),),
2308            check_lowp=False,  # a much more elaborate test is required to match finfo max's for float and half
2309        )
2310
2311    def test_one_hot(self):
2312        def fn(a):
2313            return torch.nn.functional.one_hot(a, 8) + 1
2314
2315        self.common(
2316            fn,
2317            (torch.arange(100).view(4, 5, 5) % 8,),
2318            check_lowp=False,
2319        )
2320
2321    def test_div1(self):
2322        def fn(a, b):
2323            return (
2324                aten.div(a, b, rounding_mode=None),
2325                aten.div(a, b, rounding_mode="floor"),
2326                aten.div(a, b, rounding_mode="trunc"),
2327                a / b,
2328                a // b,
2329            )
2330
2331        self.common(fn, (torch.randn(8, 8) * 100, torch.randn(8, 8) * 100))
2332
2333    def test_div2(self):
2334        def fn(a, b):
2335            return (
2336                aten.div(a, b, rounding_mode=None),
2337                aten.div(a, b, rounding_mode="floor"),
2338                aten.div(a, b, rounding_mode="trunc"),
2339                a / b,
2340                a // b,
2341            )
2342
2343        self.common(fn, (torch.randint(-100, 100, [8, 8]), 100 * torch.randn(8, 8)))
2344
2345    def test_div3(self):
2346        def fn(a, b):
2347            return (
2348                aten.div(a, b, rounding_mode=None),
2349                aten.div(a, b, rounding_mode="floor"),
2350                aten.div(a, b, rounding_mode="trunc"),
2351                a / b,
2352                a // b,
2353            )
2354
2355        a = torch.randint(1, 100, [8, 8])
2356        self.common(fn, (a * 2, a))
2357
2358    def test_div4(self):
2359        def fn(a, b):
2360            return (
2361                aten.div(a, b, rounding_mode=None),
2362                aten.div(a, b, rounding_mode="floor"),
2363                aten.div(a, b, rounding_mode="trunc"),
2364                a / b,
2365                a // b,
2366            )
2367
2368        self.common(
2369            fn,
2370            (torch.randint(-100, 0, [8, 8]), torch.randint(1, 10, [8, 8])),
2371        )
2372
2373    def test_div5(self):
2374        def fn(a, b):
2375            return (
2376                aten.div(a, b, rounding_mode=None),
2377                aten.div(a, b, rounding_mode="floor"),
2378                aten.div(a, b, rounding_mode="trunc"),
2379                a / b,
2380                a // b,
2381            )
2382
2383        # divide a scalar
2384        self.common(fn, (torch.randint(-100, 0, [8, 8]), 16))
2385
2386    def test_div6(self):
2387        def fn(a, b):
2388            return (
2389                aten.div(a, b, rounding_mode=None),
2390                aten.div(a, b, rounding_mode="floor"),
2391                aten.div(a, b, rounding_mode="trunc"),
2392                a / b,
2393                a // b,
2394            )
2395
2396        # treat boolean as integer
2397        self.common(
2398            fn,
2399            (torch.ones([8, 8], dtype=torch.bool), torch.randint(-100, -1, [8, 8])),
2400        )
2401
2402    def test_div7(self):
2403        def fn(a, b):
2404            return (
2405                aten.div(a, b, rounding_mode=None),
2406                aten.div(a, b, rounding_mode="floor"),
2407                aten.div(a, b, rounding_mode="trunc"),
2408                a / b,
2409                a // b,
2410            )
2411
2412        self.common(
2413            fn,
2414            (
2415                torch.randint(2**32, 2**40, [100, 100]),
2416                torch.randint(-10, -1, [100, 100]),
2417            ),
2418        )
2419
2420    def test_div8(self):
2421        def fn(a, b):
2422            return (
2423                aten.div(a, b, rounding_mode=None),
2424                aten.div(a * 0.5, b, rounding_mode=None),
2425                aten.div(a, b * 1.0, rounding_mode=None),
2426                aten.div(a, b, rounding_mode="floor"),
2427                aten.div(a, b, rounding_mode="trunc"),
2428                a / b,
2429                a // b,
2430            )
2431
2432        self.common(fn, (1024, 100))
2433
2434    def test_div9(self):
2435        def fn(x):
2436            return (torch.div(42, x), aten.true_divide(42, x), aten.div.Tensor(42, x))
2437
2438        self.common(fn, (torch.randn(8),))
2439
2440    def test_div_zero_dim(self):
2441        def fn(a, b):
2442            return (
2443                aten.div(a, b, rounding_mode=None),
2444                aten.div(a, b, rounding_mode="floor"),
2445                aten.div(a, b, rounding_mode="trunc"),
2446                a / b,
2447                a // b,
2448            )
2449
2450        for dtype in (torch.float32, torch.int64):
2451            self.common(
2452                fn,
2453                (
2454                    make_tensor(10, device=self.device, dtype=dtype),
2455                    make_tensor((), device=self.device, dtype=dtype, exclude_zero=True),
2456                ),
2457            )
2458            self.common(
2459                fn,
2460                (
2461                    make_tensor((), device=self.device, dtype=dtype),
2462                    make_tensor(10, device=self.device, dtype=dtype, exclude_zero=True),
2463                ),
2464            )
2465
2466    def test_div_prim(self):
2467        def fn(a, b):
2468            return (torch.ops.prims.div(a, b),)
2469
2470        for dtype in (torch.float32, torch.int64):
2471            self.common(
2472                fn,
2473                (
2474                    make_tensor(100, device=self.device, dtype=dtype),
2475                    make_tensor(
2476                        100, device=self.device, dtype=dtype, exclude_zero=True
2477                    ),
2478                ),
2479            )
2480
2481    def test_floordiv(self):
2482        def fn_floor_input(a, i):
2483            n = (i * 1.234) // 8.234
2484            return a + n
2485
2486        self.common(
2487            fn_floor_input,
2488            (make_tensor(10, device=self.device, dtype=torch.float32), 33),
2489        )
2490
2491        def fn_int_input(a, i):
2492            n = i // 8
2493            return a + n
2494
2495        self.common(
2496            fn_int_input, (make_tensor(10, device=self.device, dtype=torch.float32), 33)
2497        )
2498
2499    def test_div_precision(self):
2500        # Reproducer for https://github.com/pytorch/pytorch/issues/101039
2501
2502        def forward(x, y):
2503            z = x.div(y)
2504            return F.softmax(z, dim=-1)
2505
2506        query = torch.randn(1, 10, 40)
2507        key = torch.randn(1, 2, 40)
2508        x = torch.matmul(query, key.transpose(-2, -1))
2509        self.common(forward, (x, 1e-6))
2510
2511        x = torch.tensor(
2512            [
2513                [
2514                    [
2515                        [-16.1649, 5.6846, -5.1022, -9.1134],
2516                        [-11.5552, -2.2615, -12.8913, 10.6538],
2517                        [-7.1666, -5.3333, 2.0776, -9.7984],
2518                        [7.4469, -2.3948, 2.7371, 0.9201],
2519                    ],
2520                    [
2521                        [-8.0361, -16.3771, 22.7741, 4.4685],
2522                        [20.8047, -0.7771, -2.4355, -2.2299],
2523                        [3.8343, -2.0914, -2.4077, 2.2740],
2524                        [-15.8663, -2.7015, -12.5241, -3.0040],
2525                    ],
2526                    [
2527                        [-2.5139, 14.4393, -3.7186, 1.2255],
2528                        [5.6742, 14.1842, -8.5976, 16.8366],
2529                        [-9.7358, -3.0279, 11.8164, -4.0787],
2530                        [-9.0621, 8.2580, 29.9486, -2.4107],
2531                    ],
2532                    [
2533                        [7.3622, 12.5640, -20.5592, 13.6237],
2534                        [-11.5640, 0.8832, 16.7275, -2.5009],
2535                        [-2.0953, -12.2276, -26.2633, 4.5268],
2536                        [15.3329, -11.7492, 6.5650, -9.2483],
2537                    ],
2538                ],
2539                [
2540                    [
2541                        [7.9980, -4.9369, 3.1508, 5.2994],
2542                        [3.8052, 3.9514, 8.4987, -10.5045],
2543                        [-2.6827, -4.0010, -4.0611, 6.4091],
2544                        [-19.0318, 6.4073, 2.8923, 8.0250],
2545                    ],
2546                    [
2547                        [7.1650, -3.4585, 5.7720, -5.0305],
2548                        [-0.9765, -3.0086, 11.7114, 8.0555],
2549                        [-3.1027, -3.5514, 9.6182, -8.8526],
2550                        [-9.2348, -6.0239, 6.2528, -6.7221],
2551                    ],
2552                    [
2553                        [11.5936, 22.4139, -0.4089, -4.9889],
2554                        [14.8217, -2.3426, -17.6189, 3.7427],
2555                        [1.9546, -13.0902, 8.6293, -7.2457],
2556                        [-7.6900, -4.5796, 9.6332, -10.2631],
2557                    ],
2558                    [
2559                        [0.8027, -1.0955, 14.8404, -0.2673],
2560                        [3.2143, -1.8640, -2.9678, 6.5165],
2561                        [-3.9865, 6.5230, 6.3019, -0.4247],
2562                        [8.3185, -13.5076, 27.0986, -1.6792],
2563                    ],
2564                ],
2565            ]
2566        )
2567        x = torch.matmul(x, x)
2568        y = torch.tensor([[[0.6331]], [[1.6358]], [[-0.3459]], [[1.0196]]])
2569        self.common(forward, (x, y))
2570
2571    def test_div_by_zero(self):
2572        def fn(x, runtime_zero, runtime_neg_zero):
2573            zero = torch.zeros_like(x)
2574            return (
2575                x / 0.0,
2576                x / -0.0,
2577                zero / 0.0,
2578                x / zero,
2579                x / -zero,
2580                zero / zero,
2581                x / runtime_zero,
2582                # NOTE: -runtime_zero doesn't work as -(0.0) is broken in triton
2583                x / runtime_neg_zero,
2584                runtime_zero / runtime_neg_zero,
2585            )
2586
2587        a = torch.randn(10)
2588        zero = torch.zeros(10)
2589        neg_zero = -zero
2590        self.common(fn, (a, zero, neg_zero))
2591
2592    def test_both_scalars(self):
2593        def fn(a, b):
2594            return (
2595                aten.add(a, b),
2596                aten.add(b, a),
2597                aten.sub(a, b),
2598                aten.sub(b, a),
2599                aten.mul(a, b),
2600                aten.mul(b, a),
2601            )
2602
2603        self.common(fn, (4, 3.3), reference_in_float=False)
2604
2605    def test_sum_keepdims(self):
2606        def fn(a, b):
2607            return (torch.sum(a + b, -1, keepdim=True),)
2608
2609        self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
2610
2611    def test_large_tensor_reduction(self):
2612        if not _has_sufficient_memory(self.device, 4.5 * 1024**3):  # 4.5 GiB
2613            raise unittest.SkipTest("insufficient memory")
2614
2615        if self.device == "cpu":
2616            raise unittest.SkipTest("Fails on CPU")
2617
2618        # Test 64-bit indexing works correctly
2619        def fn(a):
2620            return torch.max(a)
2621
2622        t = torch.ones(2**32, dtype=torch.int8, device=self.device)
2623        t[-1] = 2
2624
2625        # self.common OOMs here because it copies inputs to check for mutations
2626        compiled_fn = torch._dynamo.optimize()(fn)
2627        actual = compiled_fn(t)
2628        expect = torch.tensor(2, dtype=torch.int8, device=self.device)
2629        self.assertEqual(actual, expect)
2630
2631    def test_large_broadcast_reduction(self):
2632        if self.device == "cpu":
2633            raise unittest.SkipTest("Fails on CPU")
2634
2635        # Test 64-bit indexing works correctly when inputs are less than 32-bit
2636        # but intermediate tensors require 64-bit indexing
2637        def fn(a, b):
2638            return torch.max(a + b)
2639
2640        t1 = torch.ones(1, 2**16, dtype=torch.int8, device=self.device)
2641        t2 = torch.ones(2**16, 1, dtype=torch.int8, device=self.device)
2642
2643        t1[-1, -1] = 2
2644        t2[-1, -1] = 2
2645
2646        # self.common OOMs here because it copies inputs to check for mutations
2647        compiled_fn = torch._dynamo.optimize()(fn)
2648        actual = compiled_fn(t1, t2)
2649        expect = torch.tensor(4, dtype=torch.int8, device=self.device)
2650        self.assertEqual(actual, expect)
2651
2652    def test_large_pointwise(self):
2653        if not _has_sufficient_memory(self.device, 2 * (2**31 + 1)):
2654            raise unittest.SkipTest("insufficient memory")
2655
2656        def fn(a):
2657            return a + 1
2658
2659        t = torch.ones(2**31 + 1, dtype=torch.int8, device=self.device)
2660        compiled_fn = torch._dynamo.optimize()(fn)
2661        actual = compiled_fn(t)
2662
2663        # Can't use assertEqual as it expands broadcasted inputs
2664        del t
2665        if torch.device(self.device).type == GPU_TYPE:
2666            getattr(torch, GPU_TYPE).empty_cache()
2667
2668        self.assertTrue((actual == 2).all())
2669
2670    def test_large_offset_pointwise(self):
2671        # Test 64-bit indexing is used when input views a tensor that can be
2672        # indexed with 32-bit strides but the storage offset pushes it over
2673        # INT_MAX
2674        if not _has_sufficient_memory(self.device, (2**31 + 1) + (2**30 + 1)):
2675            raise unittest.SkipTest("insufficient memory")
2676
2677        def fn(a):
2678            return a + 4
2679
2680        t = torch.ones(2**31 + 1, dtype=torch.int8, device=self.device)
2681        t[2**30 :] = 0
2682        compiled_fn = torch._dynamo.optimize()(fn)
2683        actual = compiled_fn(t[2**30 :])
2684        self.assertTrue((actual == 4).all())
2685
2686    def test_large_strided_reduction(self):
2687        # Test 64-bit indexing is used when input numel is less than INT_MAX
2688        # but stride calculations go above INT_MAX
2689        if not _has_sufficient_memory(self.device, 2**31 + 2):
2690            raise unittest.SkipTest("insufficient memory")
2691
2692        def fn(a):
2693            return torch.max(a)
2694
2695        storage = torch.ones(2**31 + 1, dtype=torch.int8, device=self.device)
2696        view = storage[::32]
2697        view[-1] = 2
2698
2699        compiled_fn = torch._dynamo.optimize()(fn)
2700        actual = compiled_fn(view)
2701        expect = torch.tensor(2, dtype=torch.int8, device=self.device)
2702        self.assertEqual(actual, expect)
2703
2704    def test_softmax(self):
2705        def fn(a, b):
2706            return (torch.softmax(a + b, -1), torch.softmax(a, 0), torch.softmax(b, 1))
2707
2708        self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
2709
2710    def test_log_softmax(self):
2711        def fn(a, b):
2712            return (F.log_softmax(a + b, -1), F.log_softmax(a, 0), F.log_softmax(b, 1))
2713
2714        self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
2715
2716    def test_transpose(self):
2717        def fn(a, b):
2718            return (
2719                torch.t(a) + b,
2720                torch.transpose(b * 2, 0, 1) + 10,
2721            )
2722
2723        self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
2724
2725    def test_permute1(self):
2726        def fn(a):
2727            return (
2728                torch.permute(a + 1, [2, 1, 4, 0, 3]) + 2,
2729                torch.permute(a, [2, 1, 4, 0, 3]) + 2,
2730            )
2731
2732        self.common(fn, (torch.randn(2, 2, 2, 2, 2),))
2733
2734    def test_permute2(self):
2735        def fn(a):
2736            a = a.unfold(0, 2, 1)
2737            a = torch.unsqueeze(a, 1)
2738            a = torch.permute(a, [0, 2, 3, -3])
2739            return (a,)
2740
2741        self.common(fn, (torch.randn(4, 4),))
2742
2743    def test_expand(self):
2744        def fn(a):
2745            return (
2746                (a + 1).expand(3, 4, 2, 3, 2) + 2,
2747                a.expand(2, 1, 2, 3, 2) + 2,
2748            ), a.expand(2, -1, 5, -1)
2749
2750        self.common(fn, (torch.randn(2, 1, 2),))
2751
2752    def test_squeeze1(self):
2753        def fn(a):
2754            return ((a + 1).squeeze() + 2, a.squeeze() + 2)
2755
2756        self.common(fn, (torch.randn(1, 2, 1, 2, 2, 1, 1),))
2757
2758    def test_squeeze2(self):
2759        def fn(a):
2760            return ((a + 1).squeeze(-1).squeeze(2) + 2, a.squeeze(0) + 2)
2761
2762        self.common(fn, (torch.randn(1, 2, 1, 2, 2, 2, 1),))
2763
2764    def test_squeeze_varargs(self):
2765        def fn(x):
2766            return x.squeeze(1, 2).clone()
2767
2768        a = torch.randn(1024, 1, 1)
2769        self.common(fn, (a,))
2770
2771    def test_simplify_loops(self):
2772        def fn(a, b):
2773            return a + b
2774
2775        self.common(
2776            fn,
2777            (
2778                torch.randn(2, 3, 4, 5, 6),
2779                torch.randn(4, 2, 3, 5, 6).permute(1, 2, 0, 3, 4),
2780            ),
2781        )
2782
2783    def test_unsqueeze(self):
2784        def fn(a):
2785            return (
2786                torch.unsqueeze(a + 1, -1) + 2,
2787                torch.unsqueeze(a, 2) + 2,
2788                torch.unsqueeze(a + 1, 0) + 2,
2789                torch.unsqueeze(a, -2) + 2,
2790            )
2791
2792        self.common(
2793            fn,
2794            (
2795                torch.randn(
2796                    2,
2797                    2,
2798                    2,
2799                    2,
2800                ),
2801            ),
2802        )
2803
2804    def test_unsqueeze_inplace(self):
2805        def fn(a):
2806            tmp1 = a + 1
2807            aten.unsqueeze_(tmp1, 2)
2808            tmp2 = aten.unsqueeze_(a + 1, 0) + 2
2809            return (tmp1, tmp2)
2810
2811        self.common(
2812            fn,
2813            (
2814                torch.randn(
2815                    2,
2816                    2,
2817                    2,
2818                    2,
2819                ),
2820            ),
2821        )
2822
2823    def test_addmm(self):
2824        def fn(a, b, c):
2825            return (torch.addmm(a + 1, b + 2, c + 3) + 4,)
2826
2827        self.common(
2828            fn,
2829            (
2830                torch.randn(8, 8),
2831                torch.randn(8, 8),
2832                torch.randn(8, 8),
2833            ),
2834        )
2835
2836    # https://github.com/pytorch/pytorch/issues/98979
2837    @skipCUDAIf(True, "cuda failed for float64 linear")
2838    @skipIfXpu(msg="Double and complex datatype matmul is not supported in oneDNN")
2839    def test_linear_float64(self):
2840        mod = torch.nn.Sequential(torch.nn.Linear(8, 16).to(torch.float64)).eval()
2841        with torch.no_grad():
2842            self.common(mod, (torch.randn(2, 8).to(torch.float64),))
2843
2844    def test_linear1(self):
2845        mod = torch.nn.Sequential(
2846            torch.nn.Linear(8, 16),
2847            torch.nn.Sigmoid(),
2848            ToTuple(),
2849        )
2850        self.common(mod, (torch.randn(2, 8),))
2851
2852    def test_linear2(self):
2853        mod = torch.nn.Sequential(
2854            torch.nn.Linear(8, 8),
2855            torch.nn.ReLU(),
2856            torch.nn.Linear(8, 8),
2857            torch.nn.ReLU(),
2858            torch.nn.Linear(8, 8),
2859            torch.nn.ReLU(),
2860            torch.nn.Linear(8, 8),
2861            torch.nn.ReLU(),
2862        )
2863        self.common(
2864            mod,
2865            (torch.randn(2, 8),),
2866            atol=1e-3,
2867            rtol=0.01,
2868        )
2869
2870    def test_bmm1(self):
2871        def fn(a, b):
2872            return (
2873                torch.bmm(a, b),
2874                torch.bmm(a + 1, b + 2) + 3,
2875            )
2876
2877        self.common(
2878            fn,
2879            (
2880                torch.randn(2, 8, 8),
2881                torch.randn(2, 8, 8),
2882            ),
2883            check_lowp=False,
2884        )
2885        self.common(
2886            fn,
2887            (
2888                torch.randn(1, 16, 8),
2889                torch.randn(1, 8, 10),
2890            ),
2891            check_lowp=False,
2892        )
2893
2894    def test_bmm2(self):
2895        def fn(a, b):
2896            return torch.bmm(a.permute(0, 2, 1), b)
2897
2898        self.common(
2899            fn,
2900            (
2901                torch.randn(1, 8, 8),
2902                torch.randn(1, 8, 8),
2903            ),
2904            check_lowp=False,
2905        )
2906
2907    @skipIfPy312  # segfaults
2908    @config.patch(force_mixed_mm=True)
2909    def test_mixed_mm(self):
2910        def fn(a, b):
2911            return torch.mm(a, b.to(a.dtype))
2912
2913        self.common(
2914            fn,
2915            (
2916                torch.randn(8, 8),
2917                torch.randint(-128, 127, (8, 8), dtype=torch.int8),
2918            ),
2919            check_lowp=True,
2920        )
2921
2922    @skipIfPy312  # segfaults
2923    @config.patch(force_mixed_mm=True)
2924    def test_mixed_mm2(self):
2925        def fn(a, b, scale, bias):
2926            return torch.mm(a, b.to(a.dtype)) * scale + bias
2927
2928        self.common(
2929            fn,
2930            (
2931                torch.randn(8, 8),
2932                torch.randint(-128, 127, (8, 8), dtype=torch.int8),
2933                torch.randn(8),
2934                torch.randn(8),
2935            ),
2936            check_lowp=True,
2937        )
2938
2939    @skipIfPy312  # segfaults
2940    @config.patch(force_mixed_mm=True)
2941    def test_mixed_mm3(self):
2942        def fn(a, b):
2943            return torch.mm(a, b.to(a.dtype))
2944
2945        # (256, 256) @ (256, 256) so different block sizes are tried out during autotuning
2946        self.common(
2947            fn,
2948            (
2949                torch.randn(256, 256),
2950                torch.randint(-128, 127, (256, 256), dtype=torch.int8),
2951            ),
2952            check_lowp=True,
2953            rtol=0.01,
2954            atol=0.1,
2955        )
2956
2957    @with_tf32_off
2958    @config.patch(use_mixed_mm=True)
2959    def test_uint4x2_mixed_mm(self):
2960        def fn(a, b):
2961            return torch.mm(
2962                a,
2963                torch.cat((b & 0xF, b >> 4), 1)
2964                .reshape(-1, b.shape[1])
2965                .to(a.dtype)
2966                .sub(8),
2967            )
2968
2969        self.common(
2970            fn,
2971            (
2972                torch.randn(8, 8),
2973                torch.randint(0, 255, (4, 8), dtype=torch.uint8),
2974            ),
2975            check_lowp=True,
2976        )
2977
2978    @expectedFailureXPU
2979    def test_mm_mixed_dtype(self):
2980        def fn(a, b):
2981            return torch.mm(a, b)
2982
2983        t1 = torch.arange(6, dtype=torch.float, device=self.device).view(2, 3)
2984        t2 = torch.arange(9, dtype=torch.int64, device=self.device).view(3, 3)
2985
2986        msg = "expected .* and .* to have the same dtype, but got: .* != .*"
2987        with self.assertRaisesRegex(RuntimeError, msg):
2988            torch.compile(fn)(t1, t2)
2989        with self.assertRaisesRegex(RuntimeError, msg):
2990            fn(t1, t2)
2991
2992    @expectedFailureXPU
2993    def test_linear_mixed_dtype(self):
2994        class Net(nn.Module):
2995            def __init__(self):
2996                super(Net, self).__init__()  # noqa: UP008
2997                self.fc1 = nn.Linear(3, 3)
2998
2999            def forward(self, x):
3000                x = self.fc1(x.permute(1, 2, 0))
3001                return x
3002
3003        fn = Net().to(self.device)
3004        t = torch.arange(27, device=self.device).view(3, 3, 3)
3005
3006        msg = "expected .* and .* to have the same dtype, but got: .* != .*"
3007        with self.assertRaisesRegex(RuntimeError, msg):
3008            fn(t)
3009        with self.assertRaisesRegex(RuntimeError, msg):
3010            with torch.no_grad():
3011                torch.compile(fn)(t)
3012        # TODO: Autograd internal assertion
3013        msg = r".*isDifferentiableType\(variable.scalar_type\(\)\) INTERNAL ASSERT FAILED.*"
3014        with self.assertRaisesRegex(RuntimeError, msg):
3015            torch.compile(fn)(t)
3016
3017    def test_scalar_input(self):
3018        def fn(x, y):
3019            a = torch.div(x, y, rounding_mode="floor")
3020            return a
3021
3022        self.common(fn, [torch.randint(5, (1, 8)), 5400])
3023
3024    @torch._dynamo.config.patch(dynamic_shapes=True)
3025    @torch._dynamo.config.patch(assume_static_by_default=False)
3026    def test_scalar_output(self):
3027        def fn(arg0_1, arg2_1):
3028            arg1_1 = arg2_1.size(1)
3029            view = torch.ops.aten.view.default(arg2_1, [-1, arg1_1])
3030            embedding = torch.ops.aten.embedding.default(arg0_1, view)
3031            full = torch.ops.aten.full.default([1, arg1_1], 1, dtype=torch.float32)
3032            return (full, arg1_1, embedding)
3033
3034        arg0_1 = rand_strided((32128, 768), (768, 1), device="cpu", dtype=torch.float32)
3035        arg2_1 = rand_strided((1, 22), (22, 1), device="cpu", dtype=torch.int64)
3036        self.common(fn, [arg0_1, arg2_1])
3037
3038    def test_shape_prop_torch_ones(self):
3039        class Model(torch.nn.Module):
3040            def forward(self, attention_scores):
3041                extended_attention_mask = torch.ones(
3042                    8, 1, 1, 512, device=attention_scores.device
3043                )
3044                attention_scores = attention_scores + extended_attention_mask
3045
3046                return attention_scores
3047
3048        mod = Model().eval()
3049        with torch.no_grad():
3050            self.common(
3051                mod,
3052                (torch.randn(8, 12, 512, 512),),
3053            )
3054
3055    @slowTest
3056    @expectedFailureCodegenDynamic
3057    @config.patch({"freezing": True})
3058    def test_conv_bn_fuse(self):
3059        # For gpu path, there is an accuracy issue
3060        if self.device == GPU_TYPE:
3061            raise unittest.SkipTest("only support cpu conv bn test")
3062
3063        # fails dynamic check which bn is fused, and there will not have loops vars.
3064        input_shapes = {1: (112,), 2: (112, 112), 3: (55, 55, 55)}
3065        conv_modules = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
3066        bn_modules = {
3067            1: torch.nn.BatchNorm1d,
3068            2: torch.nn.BatchNorm2d,
3069            3: torch.nn.BatchNorm3d,
3070        }
3071        options = itertools.product(
3072            [1, 2, 3],
3073            [True, False],
3074            [1, 3],
3075            [1, 2],
3076            [1, 4],
3077        )
3078
3079        for (
3080            dim,
3081            bias,
3082            kernel_size,
3083            dilation,
3084            groups,
3085        ) in options:
3086            oC = 32 * groups
3087            iC = 3 * groups
3088            x_shape = (1, iC) + input_shapes[dim]
3089            mod = torch.nn.Sequential(
3090                conv_modules[dim](
3091                    iC,
3092                    oC,
3093                    kernel_size=kernel_size,
3094                    dilation=dilation,
3095                    groups=groups,
3096                    bias=bias,
3097                ),
3098                bn_modules[dim](oC),
3099            ).eval()
3100            test_memory_format = [torch.contiguous_format]
3101            # TODO: GPU path doesn't support channels_last now.
3102            if not HAS_GPU and dim > 1:
3103                channels_last = (
3104                    torch.channels_last if dim == 2 else torch.channels_last_3d
3105                )
3106                test_memory_format.append(channels_last)
3107            for memory_format in test_memory_format:
3108                v = torch.randn(x_shape, dtype=torch.float32).to(
3109                    memory_format=memory_format
3110                )
3111                with torch.no_grad():
3112                    self.common(
3113                        mod,
3114                        (v,),
3115                    )
3116
3117    def test_conv_functional_bn_fuse(self):
3118        # For gpu path, there is an accuracy issue
3119        if self.device == GPU_TYPE:
3120            raise unittest.SkipTest("only support cpu conv bn test")
3121
3122        # Define a BatchNorm using functional BN.
3123        class BatchNorm(torch.nn.BatchNorm2d):
3124            def __init__(
3125                self,
3126                num_features,
3127                eps=1e-5,
3128                momentum=0.1,
3129                affine=True,
3130                track_running_stats=True,
3131                device=None,
3132                dtype=None,
3133            ):
3134                factory_kwargs = {"device": device, "dtype": dtype}
3135                super().__init__(
3136                    num_features,
3137                    eps=eps,
3138                    momentum=momentum,
3139                    affine=affine,
3140                    track_running_stats=track_running_stats,
3141                    **factory_kwargs,
3142                )
3143
3144            def forward(self, x):
3145                if self.momentum is None:
3146                    exponential_average_factor = 0.0
3147                else:
3148                    exponential_average_factor = self.momentum
3149
3150                if self.training and self.track_running_stats:
3151                    # TODO: if statement only here to tell the jit to skip emitting this when it is None
3152                    if self.num_batches_tracked is not None:  # type: ignore[has-type]
3153                        self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore[has-type]
3154                        if self.momentum is None:  # use cumulative moving average
3155                            exponential_average_factor = 1.0 / float(
3156                                self.num_batches_tracked
3157                            )
3158                        else:  # use exponential moving average
3159                            exponential_average_factor = self.momentum
3160                if self.training:
3161                    bn_training = True
3162                else:
3163                    bn_training = (self.running_mean is None) and (
3164                        self.running_var is None
3165                    )
3166                x = F.batch_norm(
3167                    x,
3168                    # If buffers are not to be tracked, ensure that they won't be updated
3169                    (
3170                        self.running_mean
3171                        if not self.training or self.track_running_stats
3172                        else None
3173                    ),
3174                    (
3175                        self.running_var
3176                        if not self.training or self.track_running_stats
3177                        else None
3178                    ),
3179                    self.weight,
3180                    self.bias,
3181                    bn_training,
3182                    exponential_average_factor,
3183                    self.eps,
3184                )
3185                return x
3186
3187        v = torch.randn(1, 3, 556, 56, dtype=torch.float32)
3188        mod = torch.nn.Sequential(
3189            torch.nn.Conv2d(
3190                3,
3191                64,
3192                kernel_size=3,
3193                dilation=1,
3194                groups=1,
3195                bias=True,
3196            ),
3197            BatchNorm(64),
3198        ).eval()
3199        with torch.no_grad():
3200            self.common(
3201                mod,
3202                (v,),
3203            )
3204
3205    @skipIfRocm
3206    def test_conv_inference_heuristics(self):
3207        if self.device != GPU_TYPE:
3208            raise unittest.SkipTest(f"{GPU_TYPE} only test")
3209
3210        in_channels = 6
3211        out_channels = 6
3212        kernel_size = 3
3213        groups = 3
3214
3215        grouped_conv = nn.Conv2d(
3216            in_channels, out_channels, kernel_size, groups=groups
3217        ).to(self.device)
3218
3219        input_tensor = torch.randn(1, in_channels, 10, 10).to(self.device)
3220
3221        # Perform the forward pass
3222        @torch.compile()
3223        def foo(m, inp):
3224            return m(inp)
3225
3226        with torch.no_grad():
3227            _, code = run_and_get_code(foo, grouped_conv, input_tensor)
3228            # no to channels last permuting before kernel
3229            FileCheck().check_not(".run(").check(".convolution(").run(code[0])
3230
3231        # in out should do channels last in inference
3232        in_channels = 8
3233        out_channels = 4
3234        kernel_size = 3
3235
3236        # Create the convolution layer
3237        conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size).to(self.device)
3238
3239        input_tensor = torch.randn(1, in_channels, 10, 10).to(self.device)
3240
3241        with torch.no_grad():
3242            _, code = run_and_get_code(foo, conv_layer, input_tensor)
3243            # should be channels last permuting before kernel
3244            FileCheck().check(".run(").check(".convolution(").run(code[0])
3245
3246    def test_upsample_cat_conv(self):
3247        if self.device == GPU_TYPE:
3248            raise unittest.SkipTest("only support cpu upsample_cat_conv test")
3249
3250        class M(torch.nn.Module):
3251            def __init__(
3252                self,
3253                **kwargs,
3254            ):
3255                super().__init__()
3256                self.upsample = torch.nn.UpsamplingNearest2d(scale_factor=2)
3257                self.conv = torch.nn.Conv2d(
3258                    8,
3259                    5,
3260                    kernel_size=1,
3261                    padding=0,
3262                    stride=1,
3263                    dilation=1,
3264                    **kwargs,
3265                )
3266
3267            def forward(self, x, y):
3268                x = self.upsample(x)
3269                z = torch.cat([x, y], dim=1)
3270                z = self.conv(z)
3271                return z
3272
3273        v1 = torch.randn([8, 2, 12, 26])
3274        v2 = torch.randn([8, 6, 24, 52])
3275
3276        with torch.no_grad():
3277            self.common(
3278                M().eval(),
3279                (v1, v2),
3280            )
3281
3282    def test_aliased_buffer_reuse(self):
3283        def fn(x, y):
3284            x = 2 * x
3285            y = 2 * y
3286            c = torch.cat([x, y], dim=-1)
3287            d = 1 + c
3288            m = torch.mm(d, d)
3289            return m[:, :2] + x
3290
3291        self.common(fn, (torch.randn(4, 2), torch.randn(4, 2)), check_lowp=False)
3292
3293    def test_slice_view_with_graph_break(self):
3294        def fn():
3295            a = torch.tensor([1], device=self.device)
3296            a = a[0:1]
3297            b = a.squeeze()
3298            a[0] = 0
3299            if a[0] < 1e5:
3300                pass
3301            a[0] = 2
3302            return b
3303
3304        expect = fn()
3305        opt_fn = torch.compile(fn)
3306        actual = opt_fn()
3307        self.assertEqual(expect, actual)
3308
3309    def test_view_detach(self):
3310        def fn(a):
3311            return a[0].detach()
3312
3313        self.common(
3314            fn,
3315            (torch.randn([4, 4], requires_grad=True),),
3316        )
3317
3318    def test_gather1(self):
3319        def fn(a, b):
3320            return (
3321                torch.gather(a.expand([4, 5, 10, 6]), 3, b + 1),
3322                torch.gather(a.expand([4, 5, 10, 6]), -1, b + 1),
3323            )
3324
3325        self.common(
3326            fn,
3327            (
3328                torch.randn([1, 1, 10, 6]),
3329                torch.randint(5, [4, 5, 10, 1], dtype=torch.int64),
3330            ),
3331        )
3332
3333    def test_gather2(self):
3334        # 0d tensor
3335        def fn(a, b):
3336            return torch.gather(a, 0, b) + torch.gather(a, -1, b)
3337
3338        x = torch.tensor(123)
3339        y = torch.tensor(0)
3340        self.assertEqual(fn(x, y), x + x)
3341
3342    def test_gather3(self):
3343        def fn(a, b):
3344            return torch.gather(a, 1, b, sparse_grad=True)
3345
3346        self.common(
3347            fn,
3348            (
3349                torch.randn([4, 5, 10, 6], requires_grad=True),
3350                torch.randint(5, [4, 5, 10, 1], dtype=torch.int64),
3351            ),
3352        )
3353
3354    def test_slice1(self):
3355        def fn(a):
3356            return (
3357                a[:, :10, 0] + a[:, 10:, 0],
3358                (a + 1)[:, :10, 0] + (a + 1)[:, 10:, 0],
3359                a[:, -30:, 0],  # negative index out of range
3360                a[:, :-30, 0],  # negative index out of range
3361            )
3362
3363        self.common(
3364            fn,
3365            (torch.randn([2, 20, 2]),),
3366        )
3367
3368    def test_slice2(self):
3369        def fn(a):
3370            return (
3371                a[:-1, ::2, -1] + a[-1:, 1::2, -2],
3372                (a + 1)[:-1, ::2, -1] + (a + 2)[-1:, 1::2, -2],
3373            )
3374
3375        self.common(
3376            fn,
3377            (torch.randn([2, 20, 2]),),
3378        )
3379
3380    # It's a view so it doens't generate a kernel
3381    @expectedFailureCodegenDynamic
3382    def test_slice3(self):
3383        def fn(a, b):
3384            return torch.ops.aten.slice.Tensor(a, 0, 0, -b)
3385
3386        x = torch.rand(48, 3, 512, 512)
3387        self.common(fn, (x, 2))
3388
3389    @expectedFailureCodegenDynamic
3390    def test_slice4(self):
3391        # empty slices that require clamping the start or end
3392        def fn(a):
3393            return (
3394                aten.slice.Tensor(a, 0, 2, 0, 1),
3395                aten.slice.Tensor(a, 0, a.shape[0], a.shape[0] + 10, 1),
3396                aten.slice.Tensor(a, 0, -20, 0, 1),
3397                aten.slice.Tensor(a, 0, -20, -16, 1),
3398            )
3399
3400        x = torch.rand(10)
3401        self.common(fn, (x,))
3402
3403    def test_split_with_list(self):
3404        def fn(a, sizes):
3405            return [t + 1.0 for t in torch.split(a * 2.0, sizes, -1)]
3406
3407        self.common(fn, (torch.randn(2, 2, 10), [3, 3, 4]))
3408        self.common(fn, (torch.randn(2, 2, 10), [4, 3, 3]))
3409        self.common(fn, (torch.randn(2, 2, 10), [1, 2, 3, 4]))
3410
3411    def test_split_with_integer(self):
3412        # argument `split_size_or_sections` is integer
3413        @torch.compile(dynamic=True)
3414        def f(x, sizes):
3415            return torch.split(x, sizes, -1)
3416
3417        # split into equally sized chunks, 10 = 5 + 5
3418        r1, r2 = f(torch.randn(2, 10), 5)
3419        self.assertTrue(r1.size() == (2, 5))
3420        self.assertTrue(r2.size() == (2, 5))
3421
3422        # split into equally sized chunks, 12 = 4 + 4 + 4
3423        r1, r2, r3 = f(torch.randn(2, 12), 4)
3424        self.assertTrue(r1.size() == (2, 4))
3425        self.assertTrue(r2.size() == (2, 4))
3426        self.assertTrue(r3.size() == (2, 4))
3427
3428        # split unevenly, 10 = 3 + 3 + 3 + 1
3429        r1, r2, r3, r4 = f(torch.randn(2, 10), 3)
3430        self.assertTrue(r1.size() == (2, 3))
3431        self.assertTrue(r2.size() == (2, 3))
3432        self.assertTrue(r3.size() == (2, 3))
3433        self.assertTrue(r4.size() == (2, 1))
3434
3435    def test_split_failed(self):
3436        @torch._dynamo.optimize("inductor")
3437        def fn(a):
3438            return torch.split(a, [2, 1, 1], dim=1)
3439
3440        with self.assertRaisesRegex(RuntimeError, ""):
3441            fn(torch.randn(1, 5))
3442
3443    def test_inductor_assert(self):
3444        @torch._dynamo.optimize("inductor", dynamic=True)
3445        def fn(a):
3446            assert a.shape[0] >= 2 and a.shape[1] >= 4
3447            return a.cos()
3448
3449        inp = torch.randn(2, 4, 6)
3450        torch._dynamo.mark_dynamic(inp, 0)
3451        torch._dynamo.mark_dynamic(inp, 1)
3452        self.assertEqual(fn(inp), inp.cos())
3453
3454    def test_split(self):
3455        def fn(a):
3456            t = torch.split(a, 3, -1)
3457            return (t[0], t[1], t[2], t[3])
3458
3459        def fn2(a):
3460            return fn(a + 1)
3461
3462        self.common(
3463            fn,
3464            (torch.randn([2, 2, 10]),),
3465        )
3466
3467        self.common(
3468            fn2,
3469            (torch.randn([2, 2, 10]),),
3470        )
3471
3472    def test_to_dtype(self):
3473        def fn(a, b):
3474            return (
3475                aten._to_copy(a, dtype=6),
3476                aten._to_copy(b + 1, dtype=6),
3477                aten.to(b, torch.float64),
3478                aten.to(b, torch.bool),
3479            )
3480
3481        self.common(
3482            fn,
3483            (
3484                torch.randn([2, 2, 10]),
3485                torch.randn([2, 2, 10], dtype=torch.float64),
3486            ),
3487        )
3488
3489    @requires_gpu()
3490    def test_to_device(self):
3491        def fn(a):
3492            if a.device.type == "cpu":
3493                return aten._to_copy(
3494                    a, device=torch.device(GPU_TYPE), dtype=6, layout=0
3495                )
3496            else:
3497                return aten._to_copy(a, device=torch.device("cpu"), dtype=6, layout=0)
3498
3499        self.common(
3500            fn,
3501            (torch.randn([2, 2, 10]),),
3502        )
3503
3504    def test_to_memory_format(self):
3505        def fn(a, memory_format):
3506            return a.to(memory_format=memory_format)
3507
3508        self.common(
3509            fn,
3510            (torch.randn([2, 2, 10, 10]), torch.channels_last),
3511        )
3512        self.common(
3513            fn,
3514            (
3515                torch.randn([2, 2, 10, 10]).to(memory_format=torch.channels_last),
3516                torch.contiguous_format,
3517            ),
3518        )
3519
3520    @requires_gpu()
3521    def test_to_device_constant(self):
3522        def fn(a):
3523            d1 = a.device.type
3524            if d1 == "cpu":
3525                d2 = GPU_TYPE
3526            else:
3527                d2 = "cpu"
3528
3529            const1 = torch.as_tensor(list(range(64)), device=d2)
3530            return (
3531                torch.arange(10, device=d2).to(d1) + a,
3532                const1.to(d1),
3533                (const1 + 1).to(d1),
3534            )
3535
3536        self.common(
3537            fn,
3538            (torch.randn([10]),),
3539        )
3540
3541    @requires_gpu()
3542    def test_multi_device(self):
3543        def fn(x):
3544            x = x + 1
3545            x = x + 2
3546            x = x.to(device=GPU_TYPE)
3547            x = x + 3
3548            x = x + 4
3549            x = x.cpu()
3550            x = x + 5
3551            x = x + 6
3552            x = x.to(device=GPU_TYPE)
3553            x = x + 7
3554            x = x + 8
3555            x = x.cpu()
3556            x = x + 9
3557            x = x + 10
3558            return x
3559
3560        self.common(
3561            fn,
3562            (torch.randn([2, 2, 10]),),
3563            check_lowp=False,  # cpu doesn't understand fp16, and there are explicit .cpu() calls
3564        )
3565
3566    @skipIfRocm
3567    @requires_multigpu()
3568    def test_multi_gpu_device(self):
3569        # TODO: https://github.com/pytorch/pytorch/issues/92627
3570        x = torch.rand([4], device=GPU_TYPE)
3571
3572        def fn(x, y):
3573            r = torch.ops.aten.div(x, y)
3574            r = r.to(f"{GPU_TYPE}:1")
3575            return 2 * r
3576
3577        self.common(fn, (torch.randn(4), torch.randn(4)), check_lowp=False)
3578
3579    @requires_multigpu()
3580    def test_multi_gpu_recompile_on_index(self):
3581        torch.set_float32_matmul_precision("high")
3582
3583        def gemm(x, y):
3584            return x @ y
3585
3586        failed_guard = None
3587
3588        def fail(guard):
3589            nonlocal failed_guard
3590            failed_guard = guard
3591
3592        gemm_opt = torch._dynamo.optimize("inductor", guard_fail_fn=fail)(gemm)
3593
3594        x0 = torch.randn(1024, 1024, device=f"{GPU_TYPE}:0")
3595        y0 = torch.randn(1024, 1024, device=f"{GPU_TYPE}:0")
3596
3597        gemm_opt(x0, y0)
3598
3599        x1 = torch.randn(1024, 1024, device=f"{GPU_TYPE}:1")
3600        y1 = torch.randn(1024, 1024, device=f"{GPU_TYPE}:1")
3601
3602        gemm_opt(x1, y1)
3603        self.assertTrue(failed_guard is not None)
3604        self.assertTrue(
3605            "tensor 'L['x']' Tensor device index mismatch. Expected device index to be"
3606            in failed_guard.reason
3607        )
3608
3609    def test_unbind(self):
3610        def fn(a):
3611            return torch.unbind(a), torch.unbind(a, -1)
3612
3613        self.common(
3614            fn,
3615            (torch.randn([4, 4, 4]),),
3616        )
3617
3618    @skipIfRocm
3619    def test_convolution1(self):
3620        m = torch.nn.Sequential(
3621            torch.nn.Conv2d(5, 6, [3, 3]),
3622            torch.nn.ReLU(),
3623            ToTuple(),
3624        )
3625
3626        self.common(
3627            m,
3628            (torch.randn([2, 5, 16, 16]),),
3629            # Mismatched elements: 10 / 2352 (0.4%)
3630            # Greatest absolute difference: 5.7220458984375e-05 at index (0, 3, 12, 12) (up to 1e-05 allowed)
3631            # Greatest relative difference: 0.06512477175897748 at index (0, 4, 11, 9) (up to 0.001 allowed)
3632            atol=6e-5,
3633            rtol=0.001,
3634        )
3635
3636    def test_convolution2(self):
3637        def fn(x, w, b):
3638            # transposed conv
3639            return (aten.convolution(x, w, b, [4], [0], [1], True, [0], 1),)
3640
3641        self.common(
3642            fn,
3643            (
3644                torch.randn([2, 32, 90]),
3645                torch.randn([32, 16, 8]),
3646                torch.randn([16]),
3647            ),
3648            check_lowp=False,
3649        )
3650
3651    @skipIfRocm
3652    def test_convolution3(self):
3653        # Test stride or padding or dilation is 1 element list.
3654        m = torch.nn.Sequential(
3655            torch.nn.Conv2d(5, 6, [3, 3], stride=[1], padding=[0], dilation=[1]),
3656            torch.nn.ReLU(),
3657            ToTuple(),
3658        )
3659
3660        self.common(
3661            m,
3662            (torch.randn([2, 5, 16, 16]),),
3663            atol=6e-5,
3664            rtol=0.001,
3665        )
3666
3667    @skipIfRocm
3668    def test_convolution4(self):
3669        def fn(x, w):
3670            x = F.conv2d(x, w, groups=w.shape[0])
3671            return x.sum()
3672
3673        self.common(
3674            fn,
3675            (
3676                torch.randn([2, 3, 16, 20]),
3677                torch.randn([3, 1, 5, 5]),
3678            ),
3679        )
3680
3681    def test_conv2d_channels_last(self):
3682        if self.device == GPU_TYPE:
3683            raise unittest.SkipTest("only support cpu conv2d channels_last")
3684
3685        m = torch.nn.Sequential(
3686            torch.nn.Conv2d(3, 3, 1, 1),
3687            ToTuple(),
3688        )
3689        # only weight is channels_last
3690        self.common(
3691            m.to(memory_format=torch.channels_last),
3692            (torch.randn([2, 3, 16, 16]),),
3693            check_lowp=False,
3694        )
3695        # only activation is channels_last
3696        self.common(
3697            m,
3698            (torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last),),
3699            check_lowp=False,
3700        )
3701        # activation and weight are all channels_last
3702        self.common(
3703            m.to(memory_format=torch.channels_last),
3704            (torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last),),
3705            check_lowp=False,
3706        )
3707
3708    def test_conv2d_backward_channels_last(self):
3709        def fn(grad_output, inp, weight):
3710            convolution_backward_8 = torch.ops.aten.convolution_backward.default(
3711                grad_output,
3712                inp,
3713                weight,
3714                [320],
3715                [1, 1],
3716                [0, 0],
3717                [1, 1],
3718                False,
3719                [0, 0],
3720                1,
3721                [True, True, True],
3722            )
3723            return convolution_backward_8
3724
3725        # only weight is channels_last
3726        self.common(
3727            fn,
3728            (
3729                torch.randn([2, 320, 8, 8]),
3730                torch.randn([2, 2048, 8, 8]),
3731                torch.randn([320, 2048, 1, 1]).to(memory_format=torch.channels_last),
3732            ),
3733            check_lowp=False,
3734        )
3735
3736    def test_conv3d_channels_last(self):
3737        if self.device == GPU_TYPE:
3738            raise unittest.SkipTest("only support cpu conv3d channels_last")
3739
3740        m = torch.nn.Sequential(
3741            torch.nn.Conv3d(3, 3, 1, 1),
3742            ToTuple(),
3743        )
3744        # only weight is channels_last
3745        self.common(
3746            m.to(memory_format=torch.channels_last_3d),
3747            (torch.randn([2, 3, 16, 16, 16]),),
3748        )
3749        # only activation is channels_last
3750        self.common(
3751            m,
3752            (torch.randn([2, 3, 16, 16, 16]).to(memory_format=torch.channels_last_3d),),
3753        )
3754        # activation and weight are all channels_last
3755        self.common(
3756            m.to(memory_format=torch.channels_last_3d),
3757            (torch.randn([2, 3, 16, 16, 16]).to(memory_format=torch.channels_last_3d),),
3758        )
3759
3760    def test_adaptive_avg_pool2d1(self):
3761        def fn(x):
3762            return aten._adaptive_avg_pool2d(x, (6, 6)), aten._adaptive_avg_pool2d(
3763                x + 1, (2, 5)
3764            )
3765
3766        self.common(
3767            fn,
3768            (torch.randn(2, 4, 16, 16),),
3769            check_lowp=False,
3770        )
3771
3772        # lowering to avg_pool2d case
3773        self.common(
3774            fn,
3775            (torch.randn(2, 4, 3, 3),),
3776        )
3777
3778        # no-op case
3779        self.common(
3780            fn,
3781            (torch.randn(2, 4, 6, 6),),
3782        )
3783
3784    def test_adaptive_avg_pool2d2(self):
3785        # Big kernel size, use fallback
3786        def fn(x):
3787            return aten._adaptive_avg_pool2d(x, (4, 4))
3788
3789        torch._inductor.metrics.generated_kernel_count = 0
3790        self.common(
3791            fn,
3792            (torch.randn(2, 4, 21, 21),),
3793            check_lowp=False,
3794        )
3795        assertGeneratedKernelCountEqual(self, 0)
3796
3797    def test_adaptive_max_pool2d1(self):
3798        def fn(x):
3799            return aten.adaptive_max_pool2d(x, (6, 6))
3800
3801        self.common(
3802            fn,
3803            (torch.randn(2, 4, 16, 16),),
3804            check_lowp=False,
3805        )
3806
3807        # lowering to max_pool2d case
3808        self.common(
3809            fn,
3810            (torch.randn(2, 4, 3, 3),),
3811        )
3812
3813        # no-op case
3814        self.common(
3815            fn,
3816            (torch.randn(2, 4, 6, 6),),
3817        )
3818
3819    def test_adaptive_max_pool2d2(self):
3820        # Big kernel size, use fallback
3821        def fn(x):
3822            return aten.adaptive_max_pool2d(x, (4, 4))
3823
3824        torch._inductor.metrics.generated_kernel_count = 0
3825        self.common(
3826            fn,
3827            (torch.randn(2, 4, 21, 21),),
3828            check_lowp=False,
3829        )
3830        assertGeneratedKernelCountEqual(self, 0)
3831
3832    def test_fractional_max_pool2d1(self):
3833        def fn(x, samples):
3834            return aten.fractional_max_pool2d(x, (3, 3), (2, 2), samples)
3835
3836        self.common(
3837            fn, (torch.randn(1, 4, 16, 16), torch.rand(1, 4, 2)), check_lowp=False
3838        )
3839
3840    def test_fractional_max_pool2d2(self):
3841        # fallback for larger kernel size
3842
3843        def fn(x, samples):
3844            return aten.fractional_max_pool2d(x, (6, 5), (3, 3), samples)
3845
3846        torch._inductor.metrics.generated_kernel_count = 0
3847        self.common(
3848            fn,
3849            (torch.randn(2, 4, 36, 36), torch.rand(2, 4, 2)),
3850            check_lowp=False,
3851        )
3852        assertGeneratedKernelCountEqual(self, 0)
3853
3854    def test_fractional_max_pool2d3(self):
3855        def fn(x, samples):
3856            return aten.fractional_max_pool2d(x, (1, 1), (16, 16), samples)
3857
3858        self.common(
3859            fn, (torch.randn(2, 4, 16, 16), torch.rand(2, 4, 2)), check_lowp=False
3860        )
3861
3862    @config.patch(fallback_random=True)
3863    def test_fractional_max_pool2d4(self):
3864        random.seed(1234)
3865        torch.manual_seed(1234)
3866
3867        # check rectangular kernel/output size
3868
3869        def fn(x):
3870            return torch.nn.functional.fractional_max_pool2d_with_indices(
3871                x, (4, 3), (3, 2)
3872            )
3873
3874        self.common(fn, (torch.randn(1, 4, 16, 16),), check_lowp=False)
3875
3876    def test_multi_threading(self):
3877        model = torch.nn.Linear(2, 3).eval()
3878        inp = torch.randn(4, 2)
3879
3880        num_run = 3
3881
3882        def run_weights_sharing_model(m, inp):
3883            with torch.no_grad():
3884                for i in range(num_run):
3885                    y = m(inp)
3886
3887        numb_instance = 2
3888        threads = []
3889        compiled_m = torch.compile(model)
3890        for i in range(1, numb_instance + 1):
3891            thread = threading.Thread(
3892                target=run_weights_sharing_model, args=(compiled_m, inp)
3893            )
3894            threads.append(thread)
3895            thread.start()
3896        for thread in threads:
3897            thread.join()
3898
3899    @unittest.skipIf(config.is_fbcode(), "fbcode triton error, needs debugging")
3900    def test_adaptive_avg_pool2d_low_prec(self):
3901        class Model(torch.nn.Module):
3902            def __init__(self):
3903                super().__init__()
3904                self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
3905
3906            def forward(self, x):
3907                x = self.avgpool(x)
3908                return x
3909
3910        mod = Model().to(self.device)
3911        for dtype in [torch.half, torch.bfloat16]:
3912            x = torch.randn(4, 3, 7, 7, device=self.device).to(dtype=dtype)
3913            opt_mod = torch.compile(mod)
3914            res = opt_mod(x)
3915            expected = mod(x)
3916            self.assertTrue(torch.allclose(res, expected))
3917
3918    def test_buffer_copied_in_graph(self):
3919        class MyModel(torch.nn.Module):
3920            def __init__(self):
3921                super().__init__()
3922                self.register_buffer("buf", torch.zeros(1))
3923                self.w1 = torch.nn.Parameter(torch.zeros(1))
3924                self.w2 = torch.nn.Parameter(torch.zeros(1))
3925
3926            def forward(self, x):
3927                self.buf.add_(1)
3928                return (self.w1 * x * self.w2).sum() + self.buf.sum()
3929
3930        model_for_eager = MyModel().to(self.device)
3931        model_for_compile = copy.deepcopy(model_for_eager)
3932
3933        eager_version_counters = [
3934            buffer._version for _, buffer in model_for_eager.named_buffers()
3935        ]
3936        compile_version_counters = [
3937            buffer._version for _, buffer in model_for_compile.named_buffers()
3938        ]
3939
3940        compiled_f = torch.compile(model_for_compile, backend="inductor")
3941
3942        inp_ref = torch.ones(1, requires_grad=True, device=self.device)
3943        inp_test = torch.ones(1, requires_grad=True, device=self.device)
3944
3945        out_ref = model_for_eager(inp_ref.clone())
3946        out_test = compiled_f(inp_test.clone())
3947
3948        eager_version_counters_after = [
3949            buffer._version for _, buffer in model_for_eager.named_buffers()
3950        ]
3951        compile_version_counters_after = [
3952            buffer._version for _, buffer in model_for_compile.named_buffers()
3953        ]
3954
3955        eager_delta = list(
3956            map(operator.sub, eager_version_counters_after, eager_version_counters)
3957        )
3958        compile_delta = list(
3959            map(operator.sub, compile_version_counters_after, compile_version_counters)
3960        )
3961
3962        self.assertEqual(eager_delta, compile_delta)
3963
3964    def test_buffer_copied_in_graph_with_different_shapes(self):
3965        class MyModel(torch.nn.Module):
3966            def __init__(self):
3967                super().__init__()
3968                self.register_buffer("buf", torch.ones(4, 4))
3969                self.w = torch.nn.Parameter(
3970                    torch.Tensor([[4, 5], [1, 2], [6, 7], [8, 9]])
3971                )
3972
3973            def forward(self, x):
3974                self.buf.add_(1)
3975                return (self.w @ x).sum() + self.buf.sum()
3976
3977        model_for_eager = MyModel().to(self.device)
3978        model_for_compile = copy.deepcopy(model_for_eager)
3979
3980        eager_version_counters = [
3981            buffer._version for _, buffer in model_for_eager.named_buffers()
3982        ]
3983        compile_version_counters = [
3984            buffer._version for _, buffer in model_for_compile.named_buffers()
3985        ]
3986
3987        compiled_f = torch.compile(model_for_compile, backend="inductor")
3988
3989        inp_ref = torch.ones(2, 4, requires_grad=True, device=self.device)
3990        inp_test = torch.ones(2, 4, requires_grad=True, device=self.device)
3991
3992        out_ref = model_for_eager(inp_ref.clone())
3993        out_test = compiled_f(inp_test.clone())
3994
3995        eager_version_counters_after = [
3996            buffer._version for _, buffer in model_for_eager.named_buffers()
3997        ]
3998        compile_version_counters_after = [
3999            buffer._version for _, buffer in model_for_compile.named_buffers()
4000        ]
4001
4002        eager_delta = list(
4003            map(operator.sub, eager_version_counters_after, eager_version_counters)
4004        )
4005        compile_delta = list(
4006            map(operator.sub, compile_version_counters_after, compile_version_counters)
4007        )
4008
4009        self.assertEqual(eager_delta, compile_delta)
4010
4011    @skipIfNNModuleInlined("https://github.com/pytorch/pytorch/issues/128198")
4012    def test_buffer_batch_norm(self):
4013        class MyModel(torch.nn.Module):
4014            def __init__(self):
4015                super().__init__()
4016                self.m = torch.nn.BatchNorm1d(100)
4017
4018            def forward(self, x):
4019                return self.m(x)
4020
4021        model_for_eager = MyModel().to(self.device)
4022        model_for_compile = copy.deepcopy(model_for_eager)
4023
4024        eager_version_counters = [
4025            buffer._version for _, buffer in model_for_eager.named_buffers()
4026        ]
4027        compile_version_counters = [
4028            buffer._version for _, buffer in model_for_compile.named_buffers()
4029        ]
4030
4031        compiled_f = torch.compile(model_for_compile, backend="inductor")
4032
4033        inp_ref = torch.ones(20, 100, requires_grad=True, device=self.device)
4034        inp_test = torch.ones(20, 100, requires_grad=True, device=self.device)
4035
4036        out_ref = model_for_eager(inp_ref.clone())
4037        out_test = compiled_f(inp_test.clone())
4038
4039        eager_version_counters_after = [
4040            buffer._version for _, buffer in model_for_eager.named_buffers()
4041        ]
4042        compile_version_counters_after = [
4043            buffer._version for _, buffer in model_for_compile.named_buffers()
4044        ]
4045
4046        eager_delta = list(
4047            map(operator.sub, eager_version_counters_after, eager_version_counters)
4048        )
4049        compile_delta = list(
4050            map(operator.sub, compile_version_counters_after, compile_version_counters)
4051        )
4052
4053        self.assertEqual(eager_delta, compile_delta)
4054
4055    def test_adaptive_avg_pool_with_output_size_0(self):
4056        m1 = nn.AdaptiveAvgPool1d(0)
4057        self.common(m1, (torch.randn(1, 2),))
4058        m2 = nn.AdaptiveAvgPool2d(0)
4059        self.common(m2, (torch.randn(1, 2, 3),))
4060
4061    def test_max_pool2d1(self):
4062        def fn(x):
4063            return aten.max_pool2d_with_indices(x, [3, 3], [2, 2])
4064
4065        self.common(
4066            fn,
4067            (torch.randn(2, 4, 16, 16),),
4068        )
4069
4070    def test_max_pool2d2(self):
4071        def fn(x):
4072            return aten.max_pool2d_with_indices(x, [3, 3], [2, 2])
4073
4074        self.common(
4075            fn,
4076            (torch.randn([16, 64, 55, 55]),),
4077        )
4078
4079    def test_max_pool2d3(self):
4080        def fn(x):
4081            # with padding
4082            return (
4083                aten.max_pool2d_with_indices(x, [3, 3], [2, 2], [1, 1]),
4084                aten.max_pool2d_with_indices(
4085                    x,
4086                    [
4087                        3,
4088                    ],
4089                    [
4090                        2,
4091                    ],
4092                    [
4093                        1,
4094                    ],
4095                ),
4096            )
4097
4098        self.common(
4099            fn,
4100            (-torch.arange(1 * 8 * 8, dtype=torch.float32).view(1, 1, 8, 8),),
4101        )
4102
4103    def test_max_pool2d4(self):
4104        def fn(x):
4105            # with padding
4106            return aten.max_pool2d_with_indices(x, [3, 3], [2, 2], [0, 0], [1, 1], True)
4107
4108        self.common(
4109            fn,
4110            (torch.randn([2, 8, 111, 111]),),
4111        )
4112
4113    def test_max_pool2d5(self):
4114        def fn(x):
4115            return aten.max_pool2d_with_indices(x, [3, 3], [])
4116
4117        self.common(
4118            fn,
4119            (torch.randn([16, 64, 55, 55]),),
4120        )
4121
4122    def test_max_pool2d6(self):
4123        # Too big kernel size, use fallback
4124        def fn(x):
4125            return aten.max_pool2d_with_indices(x, [13, 13], [])
4126
4127        torch._inductor.metrics.generated_kernel_count = 0
4128        self.common(
4129            fn,
4130            (torch.randn([16, 64, 55, 55]),),
4131        )
4132        assertGeneratedKernelCountEqual(self, 0)
4133
4134    # From https://github.com/pytorch/pytorch/issues/94775
4135    def test_max_pool2d7(self):
4136        # ceil mode turns on
4137        def fn(x):
4138            return torch.nn.functional.max_pool2d(
4139                x, 1, stride=(2, 2), padding=0, ceil_mode=True
4140            )
4141
4142        self.common(
4143            fn,
4144            (torch.randn([1, 1, 6, 7]),),
4145        )
4146
4147    # From https://github.com/pytorch/pytorch/issues/93384
4148    def test_max_pool2d8(self):
4149        # dialtion is not 1, use fallback
4150        def fn(x):
4151            return aten.max_pool2d_with_indices(x, [3, 2], [2, 1], [1, 1], [1, 2])
4152
4153        torch._inductor.metrics.generated_kernel_count = 0
4154        self.common(
4155            fn,
4156            (torch.randn([2, 2, 3, 6]),),
4157        )
4158        assertGeneratedKernelCountEqual(self, 0)
4159
4160    def test_avg_pool2d1(self):
4161        def fn(x):
4162            return aten.avg_pool2d(x, [3, 3], [2, 2])
4163
4164        self.common(
4165            fn,
4166            (torch.randn(2, 4, 16, 16),),
4167        )
4168
4169    def test_avg_pool2d2(self):
4170        def fn(x):
4171            return aten.avg_pool2d(x, [3, 3], [2, 2])
4172
4173        self.common(
4174            fn,
4175            (torch.randn([16, 64, 55, 55]),),
4176        )
4177
4178    def test_avg_pool2d3(self):
4179        def fn(x):
4180            return (
4181                aten.avg_pool2d(x, [3, 3], [2, 2], [1, 1]),
4182                aten.avg_pool2d(
4183                    x,
4184                    [
4185                        3,
4186                    ],
4187                    [
4188                        2,
4189                    ],
4190                    [
4191                        1,
4192                    ],
4193                ),
4194            )
4195
4196        self.common(
4197            fn,
4198            (-torch.arange(1 * 8 * 8, dtype=torch.float32).view(1, 1, 8, 8),),
4199        )
4200
4201    def test_avg_pool2d4(self):
4202        def fn(x):
4203            return aten.avg_pool2d(x, [3, 3], [2, 2], [0, 0], True)
4204
4205        self.common(
4206            fn,
4207            (torch.randn([2, 8, 111, 111]),),
4208        )
4209
4210    def test_avg_pool2d5(self):
4211        def fn(x):
4212            return aten.avg_pool2d(x, [3, 3], [2, 2], [1, 1], count_include_pad=False)
4213
4214        self.common(
4215            fn,
4216            (-torch.arange(1 * 8 * 8, dtype=torch.float32).view(1, 1, 8, 8),),
4217        )
4218
4219    def test_avg_pool2d6(self):
4220        def fn(x):
4221            return aten.avg_pool2d(x, [3, 3], [2, 2], [1, 1], divisor_override=3)
4222
4223        self.common(
4224            fn,
4225            (-torch.arange(1 * 8 * 8, dtype=torch.float32).view(1, 1, 8, 8),),
4226        )
4227
4228    def test_avg_pool2d7(self):
4229        # Large kernel size, use fallback
4230        def fn(x):
4231            return aten.avg_pool2d(x, [13, 13], [1, 1], [0, 0])
4232
4233        torch._inductor.metrics.generated_kernel_count = 0
4234        self.common(
4235            fn,
4236            (-torch.arange(1 * 24 * 24, dtype=torch.float32).view(1, 1, 24, 24),),
4237        )
4238        assertGeneratedKernelCountEqual(self, 0)
4239
4240    def test_avg_pool2d8(self):
4241        # https://github.com/pytorch/pytorch/issues/100987
4242        def fn(x):
4243            return aten.avg_pool2d(
4244                x, kernel_size=3, stride=2, padding=1, ceil_mode=True
4245            )
4246
4247        self.common(
4248            fn,
4249            (torch.randn(1, 3, 6, 6),),
4250        )
4251
4252    def test_alexnet_prefix(self):
4253        def forward(arg6, arg7, arg16):
4254            convolution = torch.ops.aten.convolution(
4255                arg16, arg7, arg6, [4, 4], [2, 2], [1, 1], False, [0, 0], 1
4256            )
4257            relu = torch.ops.aten.relu(convolution)
4258            max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices(
4259                relu, [3, 3], [2, 2]
4260            )
4261            getitem = max_pool2d_with_indices[0]
4262            return (getitem,)
4263
4264        self.common(
4265            forward,
4266            (
4267                rand_strided((64,), (1,), torch.float32, "cpu"),
4268                rand_strided((64, 3, 11, 11), (363, 121, 11, 1), torch.float32, "cpu"),
4269                rand_strided(
4270                    (16, 3, 224, 224), (150528, 50176, 224, 1), torch.float32, "cpu"
4271                ),
4272            ),
4273            # Mismatched elements: 127 / 746496 (0.0%)
4274            # Greatest absolute difference: 0.0009765625 at index (1, 62, 7, 16) (up to 1e-05 allowed)
4275            # Greatest relative difference: 0.05187467899332306 at index (14, 18, 11, 0) (up to 0.001 allowed)
4276            atol=3e-3,
4277            rtol=2,
4278        )
4279
4280    def test_elu(self):
4281        def fn(x):
4282            return aten.elu(x, 1.6732632423543772, 1.0507009873554805) + 2, aten.elu(
4283                x + 1, 2, 3, 4
4284            )
4285
4286        self.common(
4287            fn,
4288            (torch.randn([16, 16]),),
4289        )
4290
4291    def test_tan(self):
4292        def fn(x):
4293            return aten.tan(x) + 2, aten.tan(x + 1)
4294
4295        self.common(
4296            fn,
4297            (torch.randn([16, 16]),),
4298        )
4299
4300    def test_tanh(self):
4301        def fn(x):
4302            return aten.tanh(x) + 2, aten.tanh(x + 1)
4303
4304        self.common(
4305            fn,
4306            (torch.randn([16, 16]),),
4307        )
4308
4309    def test_lgamma(self):
4310        def fn(x):
4311            return aten.lgamma(x) + 2, aten.cos(x + 1)
4312
4313        self.common(
4314            fn,
4315            (torch.randn([16, 16]),),
4316        )
4317
4318    def test_cos(self):
4319        def fn(x):
4320            return aten.cos(x) + 2, aten.cos(x + 1)
4321
4322        self.common(
4323            fn,
4324            (torch.randn([16, 16]),),
4325        )
4326
4327    def test_sin(self):
4328        def fn(x):
4329            return aten.sin(x) + 2, aten.sin(x + 1)
4330
4331        self.common(
4332            fn,
4333            (torch.randn([16, 16]),),
4334        )
4335
4336    def test_repeat(self):
4337        def fn(x):
4338            return (
4339                x.repeat(0, 1, 1, 1),
4340                x.repeat(2, 2, 3, 1),
4341                x.repeat(8, 1, 1, 1),
4342                x.repeat(2, 1, 1, 1, 1, 1),
4343            )
4344
4345        self.common(
4346            fn,
4347            (torch.randn([1, 2, 4, 8]),),
4348        )
4349
4350    def test_repeat_as_strided(self):
4351        # Reproducer for #127474
4352
4353        def fn(x):
4354            view_size = (3, 2)
4355            full = x.repeat((3, 2))
4356            view = torch.as_strided(full, view_size, full.stride())
4357            result = view + view
4358
4359            return result
4360
4361        self.common(fn, (torch.randn(1, 1),))
4362
4363    def test_repeat_interleave(self):
4364        def fn(x):
4365            return (
4366                x.repeat_interleave(2),
4367                x.repeat_interleave(3, dim=0),
4368                x.repeat_interleave(x.size(1), dim=1),
4369            )
4370
4371        self.common(
4372            fn,
4373            (torch.randn([1, 2, 4, 8]),),
4374        )
4375
4376    @config.patch(implicit_fallbacks=True)
4377    def test_repeat_interleave_2(self):
4378        def fn(x):
4379            return torch.ops.aten.repeat_interleave.Tensor(x, output_size=12)
4380
4381        self.common(
4382            fn,
4383            (torch.tensor([2, 4, 6]),),
4384        )
4385
4386    @config.patch(fallback_random=True)
4387    def test_randn_with_dtype_and_device(self):
4388        if self.device == GPU_TYPE:
4389            raise unittest.SkipTest("only support cpu randn_with_dtype_and_device test")
4390
4391        def fn(vectors):
4392            rotations_shape = (12, vectors.shape[-1], 1, 64)
4393            random_rotations = torch.randn(
4394                rotations_shape, device=vectors.device, dtype=vectors.dtype
4395            )
4396            random_rotations += 1
4397            return random_rotations
4398
4399        self.common(
4400            fn,
4401            (torch.randn([4, 12, 2, 64]),),
4402        )
4403
4404    def test_embedding(self):
4405        m = torch.nn.Sequential(
4406            torch.nn.Embedding(10, 4, padding_idx=0),
4407            torch.nn.ReLU(),
4408            ToTuple(),
4409        )
4410
4411        self.common(
4412            m,
4413            (torch.randint(10, [2, 8]),),
4414        )
4415
4416    def test_mean(self):
4417        def fn(x):
4418            return (
4419                x.mean(),
4420                x.mean(-1),
4421                torch.mean(x, -2, keepdim=True),
4422                x.mean([0, 1]),
4423            )
4424
4425        self.common(
4426            fn,
4427            (torch.randn([1, 2, 4, 8]),),
4428        )
4429
4430    def test_var_mean(self):
4431        def fn(x):
4432            return (
4433                *torch.var_mean(x, -1),
4434                *torch.var_mean(x, [1, 3]),
4435            )
4436
4437        self.common(
4438            fn,
4439            (torch.randn([1, 2, 4, 8]),),
4440        )
4441
4442    def test_var_correction(self):
4443        def fn(x):
4444            dim = -1
4445            return (
4446                torch.var(x, dim=dim, correction=1.3),
4447                torch.var(x, dim=dim, correction=3),
4448                torch.var(x, dim=dim, correction=10),
4449            )
4450
4451        self.common(fn, (torch.randn([2, 8]),))
4452        # Unrolled reduction
4453        self.common(fn, (torch.randn([2, 4]),))
4454
4455    @config.patch(pick_loop_orders=True)
4456    def test_transposed_propagates(self):
4457        @torch._dynamo.optimize("inductor", nopython=True)
4458        def fn(x, y):
4459            return x + y
4460
4461        a = torch.randn(1, 4, 4, 4, device=self.device).permute(0, 2, 3, 1)
4462        b = torch.randn(4, 4, 4, device=self.device).permute(1, 2, 0)
4463        c = fn(a, b)
4464        self.assertEqual(a.stride(), c.stride())
4465        self.assertEqual(c.stride()[2], 1)
4466
4467    def test_std(self):
4468        def fn(x):
4469            return (
4470                torch.var(x, True),
4471                torch.var(x, False),
4472                torch.var(x, -1, True),
4473                torch.var(x, -1, False),
4474                torch.std(x, False),
4475                torch.std(x, [0, 1], True),
4476                torch.std(x, [0, 1], False),
4477                torch.std(x, -2, True, keepdim=True),
4478            )
4479
4480        self.common(
4481            fn,
4482            (torch.randn([2, 4, 4, 8]),),
4483        )
4484
4485    def test_embedding_bag(self):
4486        def fn(w, i, o):
4487            return aten._embedding_bag(w, i, o, False, 0, False, None)
4488
4489        self.common(
4490            fn,
4491            (torch.randn([10, 4]), torch.randint(10, [8]), torch.tensor([0, 2, 6])),
4492        )
4493
4494    def test_batch_norm_2d(self):
4495        m = torch.nn.Sequential(
4496            torch.nn.BatchNorm2d(10),
4497            torch.nn.ReLU(),
4498        )
4499        m.eval()
4500        self.common(m, (torch.randn([2, 10, 8, 8]),), check_lowp=False)
4501        self.common(
4502            m,
4503            (torch.randn([3, 10, 16, 16]),),
4504            check_lowp=False,  # too painful to match types of bn model
4505        )
4506
4507    # From yolov3
4508    @with_tf32_off
4509    def test_batch_norm_2d_2(self):
4510        if self.device == "cpu":
4511            raise unittest.SkipTest(f"requires {GPU_TYPE}")
4512
4513        class Repro(torch.nn.Module):
4514            def __init__(self):
4515                super().__init__()
4516                self.self_0 = torch.nn.Conv2d(
4517                    64,
4518                    128,
4519                    kernel_size=(3, 3),
4520                    stride=(2, 2),
4521                    padding=(1, 1),
4522                    bias=False,
4523                )
4524                self.self_1 = torch.nn.BatchNorm2d(
4525                    128,
4526                    eps=0.0001,
4527                    momentum=0.03,
4528                    affine=True,
4529                    track_running_stats=True,
4530                )
4531                self.self_2 = torch.nn.LeakyReLU(negative_slope=0.1, inplace=True)
4532
4533            def forward(self, l_input_: torch.Tensor):
4534                self_0 = self.self_0(l_input_)
4535                self_1 = self.self_1(self_0)
4536                self_2 = self.self_2(self_1)
4537                return (self_2,)
4538
4539        inp = torch.randn((4, 64, 192, 256), dtype=torch.float32, device=GPU_TYPE)
4540        mod = Repro().to(device=GPU_TYPE)
4541        o1 = mod(inp)
4542        o2 = torch.compile(mod)(inp)
4543        self.assertEqual(o1, o2)
4544
4545    @patch.object(config.trace, "enabled", True)
4546    def test_layer_norm(self):
4547        m = torch.nn.Sequential(
4548            torch.nn.LayerNorm(32),
4549            torch.nn.ReLU(),
4550        )
4551        m.eval()
4552        with torch.no_grad():
4553            self.common(m, (torch.randn([16, 32]),), check_lowp=False)
4554        if self.device != "cpu":
4555            assertGeneratedKernelCountEqual(self, 1)
4556
4557    def test_transpose_add(self):
4558        def fn(a, b):
4559            return a.t() + b
4560
4561        self.common(
4562            fn, (torch.randn([16, 32]), torch.randn([32, 16])), check_lowp=False
4563        )
4564        if self.device != "cpu":
4565            assertGeneratedKernelCountEqual(self, 1)
4566
4567    @patch.object(config.triton, "persistent_reductions", True)
4568    def test_softmax_one_kernel_persist(self):
4569        def fn(x):
4570            dim = 1
4571            x_max = torch.amax(x, dim, keepdim=True)
4572            unnormalized = torch.exp(x - x_max)
4573            result = unnormalized / torch.sum(unnormalized, dim, keepdim=True)
4574            return result
4575
4576        self.common(fn, (torch.randn([16, 32]),), check_lowp=False)
4577        if self.device != "cpu":
4578            assertGeneratedKernelCountEqual(self, 1)
4579
4580    @patch.object(config.triton, "persistent_reductions", False)
4581    def test_softmax_one_kernel_loop(self):
4582        def fn(x):
4583            x_max = torch.amax(x, 1, keepdim=True)
4584            unnormalized = torch.exp(x - x_max)
4585            result = unnormalized / torch.sum(unnormalized, 1, keepdim=True)
4586            return result
4587
4588        self.common(fn, (torch.randn([16, 32]),), check_lowp=False)
4589        if self.device != "cpu":
4590            assertGeneratedKernelCountEqual(self, 1)
4591
4592    def test_complex_fallback(self):
4593        def fn(x):
4594            return x * x + 10
4595
4596        self.common(
4597            fn,
4598            (torch.randn([1, 2, 4, 8]).to(dtype=torch.complex64),),
4599        )
4600        assertGeneratedKernelCountEqual(self, 0)
4601
4602        class ToComplex(nn.Module):
4603            def forward(self, x):
4604                return (x + x + 12).to(torch.complex64)
4605
4606        self.common(ToComplex(), (torch.rand([1, 2, 4, 8]),), check_lowp=False)
4607
4608        if self.device != "cpu":
4609            assertGeneratedKernelCountEqual(self, 1)
4610
4611    def test_view_as_complex(self):
4612        class Repro(torch.nn.Module):
4613            def __init__(self):
4614                super().__init__()
4615
4616            def forward(self, view_2):
4617                clone = torch.ops.aten.clone.default(
4618                    view_2, memory_format=torch.contiguous_format
4619                )
4620                view_2 = None
4621                view_as_complex = torch.ops.aten.view_as_complex.default(clone)
4622                clone = None
4623                return (view_as_complex,)
4624
4625        inp = torch.empty_strided((128, 64, 12, 32, 2), (1, 98304, 8192, 256, 128)).to(
4626            self.device
4627        )
4628        mod = Repro()
4629
4630        o1 = mod(inp)
4631        o2 = torch.compile(mod)(inp)
4632
4633        self.assertEqual(o1, o2)
4634
4635    def test_view_as_real(self):
4636        def fn(x):
4637            y = torch.view_as_real(x)
4638            return y + 1
4639
4640        x = torch.randn(4, dtype=torch.complex64)
4641
4642        self.common(fn, (x,))
4643
4644    def test_cauchy(self):
4645        def fn(x, y):
4646            return torch.sum(1 / (torch.unsqueeze(x, -1) - y))
4647
4648        self.common(
4649            fn,
4650            (
4651                torch.randn(32),
4652                torch.randn(32),
4653            ),
4654            # Absolute difference: 0.0003662109375 (up to 0.0001 allowed)
4655            # Relative difference: 1.8804297408767818e-05 (up to 1e-05 allowed)
4656            atol=5 * 1e-4,
4657            rtol=5 * 1e-5,
4658            check_lowp=False,
4659        )
4660        if self.device != "cpu":
4661            assertGeneratedKernelCountEqual(self, 1)
4662
4663    def test_fusing_write_into_disjoint_read(self):
4664        def test_flip(a):
4665            return a.copy_(torch.flip(a, (0,)))
4666
4667        self.common(test_flip, (torch.rand([20]),))
4668
4669        assertGeneratedKernelCountEqual(self, 2)
4670
4671        # issue only manifests on cuda with large tensors
4672        if self.device != "cpu":
4673
4674            def f(a):
4675                a[:, 20:40] = a[:, 20:40] + 1
4676                a[:, 2:900025] = a[:, 1:900024] + 2
4677
4678            a = torch.rand((1, 1000000), device=GPU_TYPE)
4679            self.common(f, (a,))
4680
4681    def test_gather_scatter(self):
4682        def fn(node_feat, edge_index):
4683            src_node_feat = node_feat[edge_index[0]]
4684            dst_node_feat = node_feat[edge_index[1]]
4685            edge_feat = src_node_feat - dst_node_feat + 1
4686            new_node_feat = torch.zeros_like(node_feat)
4687            new_node_feat.scatter_add_(
4688                0, edge_index[1].unsqueeze(-1).expand_as(edge_feat), edge_feat
4689            )
4690            return new_node_feat
4691
4692        num_nodes = 16
4693        num_features = 32
4694        node_feat = torch.randn(num_nodes, num_features)
4695        edge_index = torch.randint(0, num_nodes, size=(2, num_nodes * 5))
4696        self.common(
4697            fn,
4698            (
4699                node_feat,
4700                edge_index,
4701            ),
4702            check_lowp=False,
4703        )
4704        if self.device != "cpu":
4705            assertGeneratedKernelCountEqual(self, 2)
4706
4707    @config.patch(max_fusion_size=1)
4708    def test_no_mega_fusion_during_lowering(self):
4709        n = 50
4710
4711        def fn(*args):
4712            x = args[0]
4713            for i in range(n):
4714                x = torch.add(x, args[i])
4715            return x
4716
4717        self.common(
4718            fn,
4719            [torch.randn(64) for _ in range(n)],
4720            check_lowp=False,
4721        )
4722        print("-->", torch._inductor.metrics.generated_kernel_count)
4723        if self.device != "cpu":
4724            self.assertTrue(torch._inductor.metrics.generated_kernel_count > 1)
4725
4726    def test_move_arange(self):
4727        def fn(x):
4728            return torch.arange(len(x), device="cpu").to(x.device) + x
4729
4730        self.common(fn, (torch.randn([32]),), check_lowp=False)
4731        # if we have a copy there will be more than 1 kernel
4732        assertGeneratedKernelCountEqual(self, 1)
4733
4734    def test_leaky_relu(self):
4735        def fn(x):
4736            return aten.leaky_relu(x, 0.2) + 2, aten.leaky_relu(x + 1)
4737
4738        self.common(
4739            fn,
4740            (torch.randn([16, 16]),),
4741        )
4742
4743    def test_gelu(self):
4744        def fn(x):
4745            return aten.gelu(x) + 2, aten.gelu(x + 1)
4746
4747        self.common(
4748            fn,
4749            (torch.randn([16, 16]),),
4750        )
4751
4752    def test_clone(self):
4753        def fn(x):
4754            return aten.clone(x) + 2, aten.clone(x + 1)
4755
4756        self.common(
4757            fn,
4758            (torch.randn([16, 16]),),
4759        )
4760
4761    def test_masked_fill(self):
4762        def fn(mask, value):
4763            return aten.masked_fill(value, mask, -10000.0) + 2, aten.masked_fill(
4764                value / 2.0, torch.logical_not(mask), 667
4765            )
4766
4767        self.common(
4768            fn,
4769            (
4770                torch.randint(0, 1, [1, 16], dtype=torch.bool),
4771                torch.randn([16, 16]),
4772            ),
4773        )
4774
4775    def test_masked_fill_promotion(self):
4776        def fn(mask, value):
4777            return aten.masked_fill(value, mask, torch.tensor(3.5))
4778
4779        opt_fn = torch._dynamo.optimize("inductor")(fn)
4780        for inp in (
4781            torch.randn(
4782                [16, 16],
4783                dtype=torch.float16 if self.device == GPU_TYPE else torch.float32,
4784                device=self.device,
4785            ),
4786            torch.randint(16, (16, 16), device=self.device),
4787        ):
4788            inputs = (
4789                torch.randint(0, 1, [1, 16], dtype=torch.bool, device=self.device),
4790                inp,
4791            )
4792            self.assertEqual(fn(*inputs), opt_fn(*inputs))
4793
4794    def test_masked_scatter(self):
4795        def fn(value, mask, source):
4796            return torch.masked_scatter(value, mask, source)
4797
4798        value = make_tensor(10, 10, dtype=torch.float32, device=self.device)
4799        mask = make_tensor(10, 10, dtype=torch.bool, device=self.device)
4800        source = make_tensor(
4801            mask.count_nonzero(), dtype=torch.float32, device=self.device
4802        )
4803
4804        self.common(fn, (value, mask, source))
4805
4806    def test_fill1(self):
4807        def fn(x):
4808            tmp = torch.ones_like(x)
4809            return tmp, aten.fill.Scalar(tmp, 2)
4810
4811        self.common(
4812            fn,
4813            (torch.randn([16, 16]),),
4814        )
4815
4816    def test_fill2(self):
4817        def fn(x):
4818            tmp = torch.ones_like(x)
4819            return tmp, aten.fill.Tensor(tmp, torch.tensor(3.0))
4820
4821        self.common(
4822            fn,
4823            (torch.randn([16, 16]),),
4824        )
4825
4826    def test_pow1(self):
4827        def fn(x):
4828            return [aten.pow(x, e) for e in range(-8, 9)]
4829
4830        self.common(
4831            fn,
4832            (torch.randn([16, 16]),),
4833        )
4834
4835    def test_pow2(self):
4836        def fn(x):
4837            return aten.pow(1000, x), aten.pow(x, 1000)
4838
4839        self.common(
4840            fn,
4841            (
4842                torch.randn(
4843                    [16, 16],
4844                    dtype=torch.float32,
4845                ),
4846            ),
4847            # Mismatched elements: 9 / 256 (3.5%)
4848            # Greatest absolute difference: 2.491354329061828e+28 at index (6, 6) (up to 1e-05 allowed)
4849            # Greatest relative difference: 2.9793410720160818e-05 at index (4, 5) (up to 1.3e-06 allowed)
4850            atol=1e-5,
4851            rtol=3e-05,
4852        )
4853
4854    def test_pow3(self):
4855        # power of 0.5 is special-cased, arbitrary power would still produce triton codegen error
4856        def fn(x):
4857            z = torch.tensor(0.123, device=self.device)
4858            w = z + x
4859            return torch.pow(w, 0.5)
4860
4861        opt = torch._dynamo.optimize("inductor")(fn)
4862        input = torch.rand(())
4863        self.assertTrue(same(opt(input), fn(input)))
4864
4865    def test_pow_int(self):
4866        def fn(x, y):
4867            return torch.pow(x, 0x57), torch.pow(x, y)
4868
4869        for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
4870            intmax = torch.iinfo(dtype).max
4871            make_arg = functools.partial(
4872                make_tensor, dtype=dtype, device=self.device, requires_grad=False
4873            )
4874            self.common(
4875                fn,
4876                (
4877                    make_arg(16, 16),
4878                    make_arg(16, 16, high=intmax),
4879                ),
4880            )
4881
4882    def test_glu(self):
4883        def fn(x):
4884            return aten.glu(x, -1), aten.glu(x, 1), aten.glu(x, 2)
4885
4886        self.common(
4887            fn,
4888            (torch.randn([8, 16, 8, 8]),),
4889        )
4890
4891    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
4892    def test_nonzero_unbacked_refinement(self):
4893        def fn(x):
4894            z = x.nonzero()
4895            torch._check(z.size(0) == 4)
4896            return z + 3
4897
4898        self.common(
4899            fn,
4900            (torch.tensor([0, 1, 3, 4, 2, 0, 0]),),
4901        )
4902
4903        with self.assertRaises(RuntimeError):
4904            torch.compile(fn)(torch.tensor([0, 0, 0, 0]))
4905
4906    @torch._dynamo.config.patch(capture_scalar_outputs=True)
4907    def test_unbacked_floordiv_simplify(self):
4908        def fn(x, y):
4909            z = y.item()
4910            torch._check(z // 2 == 3)
4911            return x + x.new_zeros(z)
4912
4913        self.common(
4914            fn,
4915            (
4916                torch.randn(6),
4917                torch.tensor([6]),
4918            ),
4919        )
4920
4921        self.common(
4922            fn,
4923            (
4924                torch.randn(7),
4925                torch.tensor([7]),
4926            ),
4927        )
4928
4929    @torch._dynamo.config.patch(capture_scalar_outputs=True)
4930    def test_unbacked_floordiv_simplify_errors(self):
4931        def fn(x, y):
4932            z = y.item()
4933            torch._check(z // 2 == 3)
4934            return x + x.new_zeros(z)
4935
4936        # This is a little suboptimal: we actually fail /in the compiler/ but
4937        # not in a way that causes Dynamo to graph break
4938        with self.assertRaises(RuntimeError):
4939            torch.compile(fn)(torch.randn(8), torch.tensor(8))
4940
4941    def test_cat(self):
4942        def fn(a):
4943            tmp = a * 2
4944            return (
4945                torch.cat((a, a[:, :4] + 1, a + 2), -1),
4946                torch.cat((tmp, tmp), 0),
4947                torch.cat((tmp, tmp.double()), 0),
4948            )
4949
4950        self.common(
4951            fn,
4952            (torch.randn([8, 16]),),
4953        )
4954        self.common(
4955            fn,
4956            (torch.randn([1, 3, 3, 16]).to(memory_format=torch.channels_last),),
4957        )
4958
4959    def test_cat_uint8(self):
4960        def fn(x):
4961            batch_shape = x.shape[:1]
4962            out = torch.cat([x.new_zeros(1).expand(batch_shape + (1,)), x], dim=-1)
4963            return out
4964
4965        self.common(
4966            fn,
4967            (torch.randint(0, 256, size=(3, 255), dtype=torch.uint8),),
4968        )
4969
4970    def test_cat_empty(self):
4971        def fn_2(*tensors):
4972            return torch.cat(tensors)
4973
4974        self.common(
4975            fn_2,
4976            (
4977                torch.randn([1, 3, 3, 16]),
4978                torch.ones([0]),
4979            ),
4980        )
4981        self.common(
4982            fn_2,
4983            (
4984                torch.randn([1, 3, 3, 16]),
4985                torch.ones([0]),
4986                torch.randn([1, 3, 3, 16]),
4987            ),
4988        )
4989        self.common(
4990            fn_2,
4991            (
4992                torch.ones([0]),
4993                torch.randn([1, 3, 3, 16]),
4994            ),
4995        )
4996
4997    @torch._dynamo.config.patch(capture_scalar_outputs=True)
4998    def test_cat_unbacked_legacy_empty(self):
4999        def fn(x, y):
5000            z = y.item()
5001            return torch.cat([x, x.new_ones(z)])
5002
5003        self.common(
5004            fn,
5005            (
5006                torch.randn([2, 3]),
5007                torch.tensor([0]),
5008            ),
5009        )
5010
5011    @torch._dynamo.config.patch(capture_scalar_outputs=True)
5012    def test_cat_unbacked_empty_1d(self):
5013        def fn(x, y):
5014            z = y.item()
5015            return torch.cat([x, x.new_ones(z)])
5016
5017        self.common(
5018            fn,
5019            (
5020                torch.randn([2]),
5021                torch.tensor([0]),
5022            ),
5023        )
5024
5025        self.common(
5026            fn,
5027            (
5028                torch.randn([2]),
5029                torch.tensor([3]),
5030            ),
5031        )
5032
5033    @torch._dynamo.config.patch(capture_scalar_outputs=True)
5034    def test_cat_unbacked_2d(self):
5035        def fn(x, y):
5036            z = y.item()
5037            return torch.cat([x, x.new_ones(z, x.shape[1])])
5038
5039        self.common(
5040            fn,
5041            (
5042                torch.randn([2, 3]),
5043                torch.tensor([0]),
5044            ),
5045        )
5046
5047        self.common(
5048            fn,
5049            (
5050                torch.randn([2, 3]),
5051                torch.tensor([4]),
5052            ),
5053        )
5054
5055    def test_cat_negative_dim(self):
5056        def fn(*tensors):
5057            return torch.cat(tensors, dim=-1)
5058
5059        self.common(
5060            fn,
5061            (
5062                torch.randn([2, 3]),
5063                torch.randn([2, 4]),
5064            ),
5065        )
5066
5067        self.common(
5068            fn,
5069            (
5070                torch.randn([2, 3]),
5071                torch.randn([0]),
5072                torch.randn([2, 4]),
5073            ),
5074        )
5075
5076        self.common(
5077            fn,
5078            (
5079                torch.randn([0]),
5080                torch.randn([2, 3]),
5081                torch.randn([2, 4]),
5082            ),
5083        )
5084
5085    @expectedFailureCodegenDynamic
5086    def test_cat_single_empty(self):
5087        # fails dynamic check for 'has a dynamic dimension'
5088        def fn_2(*tensors):
5089            return torch.cat(tensors)
5090
5091        self.common(
5092            fn_2,
5093            (torch.ones([0]),),
5094        )
5095
5096    def test_cat_upcasting(self):
5097        def fn(arg4_1, slice_7):
5098            cat_1 = aten.cat.default([arg4_1, slice_7], 1)
5099            return (cat_1,)
5100
5101        self.common(
5102            fn,
5103            (
5104                torch.randn([8, 16], dtype=torch.float32),
5105                torch.randn([8, 20], dtype=torch.float16),
5106            ),
5107        )
5108
5109    def test_cat_extern_kernel(self):
5110        def fn(x1, x2, x3, x4):
5111            x = torch.mm(x2, x3)
5112            s = torch.narrow(x, 1, 0, 100)
5113            x = torch.mm(s, x4)
5114            c = torch.cat((x, x1), 1)
5115            return (c,)
5116
5117        if self.device == "xpu":
5118            atol = 3e-4
5119            rtol = 1e-4
5120        else:
5121            # use default
5122            atol = None
5123            rtol = None
5124        self.common(
5125            fn,
5126            (
5127                torch.randn(256, 256),
5128                torch.randn(256, 1024),
5129                torch.randn(1024, 1600),
5130                torch.randn(100, 256),
5131            ),
5132            atol=atol,
5133            rtol=rtol,
5134            check_lowp=False,  # accuracy issues with relatively large matmuls
5135        )
5136
5137    @skipCUDAIf(not SM80OrLater, "uses bfloat16 which requires SM >= 80")
5138    # Constant folding was explicitly turned off due to issue #108388
5139    # Turn it back on for test
5140    @torch._inductor.config.patch(joint_graph_constant_folding=True)
5141    def test_remove_no_ops(self):
5142        def matmul_with_op(x, y, fn):
5143            return fn(x @ y)
5144
5145        foo_opt = torch.compile(matmul_with_op)
5146
5147        # test no-op
5148        fns = (
5149            lambda x: x
5150            + torch.zeros(
5151                [256, 256], dtype=torch.float32, device=x.device
5152            ),  # noqa: E731
5153            lambda x: x
5154            - torch.zeros(
5155                [256, 256], dtype=torch.float32, device=x.device
5156            ),  # noqa: E731
5157            lambda x: x
5158            * torch.ones(
5159                [256, 256], dtype=torch.float32, device=x.device
5160            ),  # noqa: E731
5161            lambda x: x
5162            / torch.ones(
5163                [256, 256], dtype=torch.float32, device=x.device
5164            ),  # noqa: E731
5165        )
5166
5167        inps = [torch.rand([256, 256], device=self.device) for _ in range(2)]
5168
5169        for fn in fns:
5170            out, source_codes = run_and_get_code(foo_opt, inps[0], inps[1], fn)
5171            self.assertEqual(out, matmul_with_op(inps[0], inps[1], fn))
5172
5173            if self.device == "cpu":
5174                FileCheck().check_not("cpp_fused").run(source_codes[0])
5175            else:
5176                FileCheck().check_not("triton.jit").run(source_codes[0])
5177
5178        # test dtype conversion
5179        inps = [
5180            torch.rand([256, 256], device=self.device, dtype=torch.bfloat16)
5181            for _ in range(2)
5182        ]
5183        for fn in fns:
5184            out, source_codes = run_and_get_code(foo_opt, inps[0], inps[1], fn)
5185            self.assertEqual(out, matmul_with_op(inps[0], inps[1], fn))
5186
5187        # test broadcasted shape bail
5188        fn = lambda x: x + torch.zeros(  # noqa: E731
5189            [256, 256, 256], dtype=torch.bfloat16, device=self.device
5190        )
5191        out, source_codes = run_and_get_code(foo_opt, inps[0], inps[1], fn)
5192        self.assertEqual(out, matmul_with_op(inps[0], inps[1], fn))
5193
5194    def test_remove_noop_copy(self):
5195        def fn(x, y):
5196            x = x.cos()
5197            a = x.copy_(y)
5198            return a.sin()
5199
5200        self.common(fn, (torch.randn(8, 8), torch.randn(8)))
5201
5202        def fn2(a, b):
5203            abs_max = torch.abs(a).max()
5204            b[0] = abs_max.to(a.dtype)
5205            return b
5206
5207        self.common(
5208            fn2,
5209            (
5210                torch.randn(8, 8, dtype=torch.float16),
5211                torch.randn(8, dtype=torch.float32),
5212            ),
5213        )
5214
5215    def test_remove_noop_clone(self):
5216        def fn(x):
5217            y = x.clone().reshape(-1, 4)
5218            y[:, [2, 0]] = y[:, [0, 2]]
5219            return y + x
5220
5221        self.common(fn, (torch.randn(2, 4),))
5222
5223    def test_cat_of_loops_and_extern_kernel(self):
5224        class M(torch.nn.Module):
5225            def __init__(
5226                self,
5227                **kwargs,
5228            ):
5229                super().__init__()
5230                self.conv = torch.nn.Conv2d(
5231                    64,
5232                    5,
5233                    1,
5234                    **kwargs,
5235                )
5236                self.max_pool2d = torch.nn.MaxPool2d(2)
5237
5238            def forward(self, x, y):
5239                x1 = self.conv(x)
5240                y1 = self.max_pool2d(y)
5241                return torch.cat([x1, y1], 1)
5242
5243        mod = M()
5244        opt_mod = torch._dynamo.optimize("inductor")(mod)
5245        memory_format = torch.channels_last
5246        inputs = (
5247            torch.randn([1, 64, 16, 16]).to(memory_format=memory_format),
5248            torch.randn([1, 64, 32, 32]).to(memory_format=memory_format),
5249        )
5250        y = mod(*inputs)
5251        opt_y = opt_mod(*inputs)
5252        self.assertEqual(y, opt_y)
5253        self.assertEqual(y.stride(), opt_y.stride())
5254
5255    def test_cat_inplace(self):
5256        def fn(x):
5257            rt = torch.cat([x])
5258            v = x.sin_()
5259            return rt
5260
5261        # can't use self.common because input is modified inplace
5262        inp = torch.ones(2)
5263        opt_fn = torch.compile(fn)
5264        res = opt_fn(inp.clone())
5265        expected = fn(inp.clone())
5266        self.assertEqual(res, expected)
5267
5268    def test_stack(self):
5269        def fn(a, b):
5270            return torch.stack(
5271                [
5272                    a.expand(12, 16),
5273                    b.expand(12, 16),
5274                ],
5275                2,
5276            )
5277
5278        self.common(fn, (torch.randn([1, 16]), torch.randn([12, 1])))
5279
5280    def test_hardtanh(self):
5281        def fn(x):
5282            return F.hardtanh(x), F.hardtanh(x + 1), F.hardtanh(x - 1)
5283
5284        self.common(
5285            fn,
5286            (torch.randn([64]),),
5287        )
5288
5289    def test_hardsigmoid(self):
5290        def fn(x):
5291            return F.hardsigmoid(x), F.hardsigmoid(x + 3), F.hardsigmoid(x - 3)
5292
5293        self.common(
5294            fn,
5295            (torch.randn([64]),),
5296        )
5297
5298    def test_hardswish(self):
5299        def fn(x):
5300            return F.hardswish(x), F.hardswish(x + 3), F.hardswish(x - 3)
5301
5302        self.common(
5303            fn,
5304            (torch.randn([64]),),
5305        )
5306
5307    def test_rsqrt(self):
5308        def fn(x):
5309            return torch.rsqrt(x), torch.rsqrt(x + 1) - 2
5310
5311        self.common(
5312            fn,
5313            (torch.randn([64]),),
5314        )
5315
5316    def test_expm1(self):
5317        def fn(x):
5318            return torch.expm1(x), torch.expm1(x) * 2
5319
5320        for dtype in (torch.float16, torch.float, torch.double, torch.int, torch.int64):
5321            self.common(
5322                fn,
5323                (torch.randn([64]).to(dtype=dtype),),
5324            )
5325            self.common(
5326                fn,
5327                (torch.arange(-1e-5, 1e-5, 1e-7).to(dtype=dtype),),
5328            )
5329
5330    def test_log1p(self):
5331        def fn(x):
5332            return torch.log1p(x), torch.log1p(x) * 2
5333
5334        for dtype in (torch.float16, torch.float, torch.double, torch.int, torch.int64):
5335            self.common(
5336                fn,
5337                (torch.randn([64]).to(dtype=dtype),),
5338            )
5339            self.common(
5340                fn,
5341                (torch.arange(-1e-5, 1e-5, 1e-7).to(dtype=dtype),),
5342            )
5343
5344    def test_flip(self):
5345        def fn(x):
5346            return torch.flip(x, (-1,)), torch.flip(x, (0, 2)) - 2
5347
5348        self.common(
5349            fn,
5350            (torch.randn([1, 2, 6, 6]),),
5351        )
5352
5353    def test_signbit(self):
5354        def fn(x):
5355            return torch.signbit(x), ~torch.signbit(-x) & 1
5356
5357        self.common(
5358            fn,
5359            (torch.randn([1, 2, 6, 6]),),
5360        )
5361
5362    def test_sign_dtype(self):
5363        def fn(x):
5364            y = torch.sign(x)
5365            return torch.tanh(y)
5366
5367        self.common(fn, (torch.randn([1, 2, 6, 6]),))
5368
5369    def test_fmod(self):
5370        def fn(a, b):
5371            return torch.fmod(a, b), torch.fmod(3.0 * a, b) - 2.0
5372
5373        shape = [1, 2, 6, 6]
5374        self.common(fn, (torch.randn(shape), torch.randn(shape)))
5375
5376    def test_fmod_zero_dim(self):
5377        def fn(a, b):
5378            return (torch.fmod(a, b),)
5379
5380        self.common(
5381            fn,
5382            (
5383                make_tensor(10, device=self.device, dtype=torch.float32),
5384                make_tensor((), device=self.device, dtype=torch.float32),
5385            ),
5386        )
5387        self.common(
5388            fn,
5389            (
5390                make_tensor((), device=self.device, dtype=torch.float32),
5391                make_tensor(10, device=self.device, dtype=torch.float32),
5392            ),
5393        )
5394
5395    def test_log2(self):
5396        def fn(x):
5397            return torch.log2(x), torch.log2(x + 1) - 2
5398
5399        self.common(
5400            fn,
5401            (torch.randn([64]) + 10,),
5402        )
5403
5404    def test_logsumexp(self):
5405        def fn(x):
5406            return torch.logsumexp(x, -1), torch.logsumexp(x, 0) - 2
5407
5408        self.common(
5409            fn,
5410            (torch.randn([8, 8]) + 10,),
5411        )
5412
5413    def test_log_fp64(self):
5414        def fn(x):
5415            return torch.log(x), torch.log2(x)
5416
5417        self.common(
5418            fn,
5419            (torch.randn([1024], dtype=torch.float64) + 10,),
5420        )
5421
5422    def test_bitwise(self):
5423        def fn(x, y):
5424            return (
5425                torch.bitwise_not(x),
5426                torch.bitwise_or(x, y),
5427                torch.bitwise_xor(x, y),
5428                torch.bitwise_and(x, y),
5429            )
5430
5431        self.common(
5432            fn,
5433            (
5434                torch.randint(0, 2**30, [64], dtype=torch.int32),
5435                torch.randint(0, 2**30, [64], dtype=torch.int32),
5436            ),
5437        )
5438
5439    def test_bitwise2(self):
5440        # again with bool types
5441        def fn(x, y):
5442            return (
5443                torch.bitwise_not(x),
5444                torch.bitwise_or(x, y),
5445                torch.bitwise_xor(x, y),
5446                torch.bitwise_and(x, y),
5447            )
5448
5449        self.common(
5450            fn,
5451            (
5452                torch.randint(0, 2, (2, 20), dtype=torch.bool),
5453                torch.randint(0, 2, (2, 20), dtype=torch.bool),
5454            ),
5455        )
5456
5457    def test_bitwise3(self):
5458        # Repro for https://github.com/pytorch/pytorch/issues/97968
5459        def fn(x, y):
5460            return (
5461                torch.max(torch.bitwise_and(x, y), y),
5462                torch.clamp_max(torch.bitwise_or(x, y), y),
5463                torch.clamp_min(torch.bitwise_xor(x, y), y),
5464            )
5465
5466        self.common(
5467            fn,
5468            (
5469                torch.rand([5, 10, 1]).to(torch.int8),
5470                torch.rand([10, 1]).to(torch.int8),
5471            ),
5472        )
5473
5474    def test_inf(self):
5475        def fn(a):
5476            return a + float("inf"), a + float("-inf"), a * -float("inf")
5477
5478        self.common(fn, (torch.randn(8),))
5479
5480    def test_remainder(self):
5481        def fn(a, b):
5482            return (
5483                torch.remainder(a, b),
5484                torch.remainder(a + 1, b - 1),
5485                torch.remainder(a - 1, b + 1),
5486            )
5487
5488        self.common(fn, (torch.randn(64), torch.randn(64)))
5489
5490    def test_zeros(self):
5491        def fn(a):
5492            return (
5493                a + 1,
5494                torch.zeros(
5495                    (1, 8, 64, 64),
5496                    dtype=torch.float32,
5497                    device=a.device,
5498                ),
5499                torch.zeros(
5500                    1,
5501                    8,
5502                    64,
5503                    64,
5504                    dtype=torch.float32,
5505                    device=a.device,
5506                ),
5507                torch.zeros(2, 3),
5508                a + torch.ones(8, device=a.device),
5509                torch.full((2, 3), 3.1416, device=a.device),
5510            )
5511
5512        self.common(fn, (torch.randn(8),))
5513
5514    def test_new_ones(self):
5515        def fn(a):
5516            return (
5517                aten.new_ones(
5518                    a, [], device=a.device, dtype=6, layout=0, pin_memory=False
5519                ),
5520                aten.new_zeros(
5521                    a, [], device=a.device, dtype=6, layout=0, pin_memory=False
5522                ),
5523            )
5524
5525        self.common(fn, (torch.randn(8),))
5526
5527    def test_full_like(self):
5528        def fn(a):
5529            return torch.full_like(a, 7.777) - 1
5530
5531        self.common(fn, (torch.randn(8),))
5532
5533    def test_full_truncation(self):
5534        def fn(a):
5535            return a + torch.full_like(a, 7.777)
5536
5537        for dtype in all_types():
5538            self.common(fn, (make_tensor(8, dtype=dtype, device=self.device),))
5539
5540    def test_full_boolean(self):
5541        def fn(n):
5542            x = torch.full((1,), n >= 1024, device=self.device)
5543            return x, x + 1
5544
5545        self.common(fn, (1024,))
5546        self.common(fn, (1023,))
5547
5548    def test_index1(self):
5549        def fn(a, b, c):
5550            return aten.index(a, [b, c])
5551
5552        self.common(
5553            fn,
5554            (
5555                torch.randn(8, 8, 12),
5556                torch.tensor([0, 0, 2, 2], dtype=torch.int64),
5557                torch.tensor([3, 4, 4, 3], dtype=torch.int64),
5558            ),
5559        )
5560        self.common(
5561            fn,
5562            (
5563                torch.randn(8, 8, 12),
5564                torch.tensor([[0, 0, 2, 2]], dtype=torch.int64),
5565                torch.tensor([[3], [4], [4], [3]], dtype=torch.int64),
5566            ),
5567        )
5568
5569    def test_index2(self):
5570        def fn(a, b):
5571            return (
5572                aten.index(a, [b]),
5573                aten.index(a, [None, b]),
5574            )
5575
5576        self.common(
5577            fn,
5578            (
5579                torch.randn(8, 8, 8),
5580                torch.tensor([[0, 0, 2, 2]], dtype=torch.int64),
5581            ),
5582        )
5583
5584    def test_index3(self):
5585        def fn(x, ia, ib):
5586            return (x[:, ia, None, ib, 0],)
5587
5588        self.common(
5589            fn,
5590            (
5591                torch.randn(3, 4, 4, 4, 3),
5592                torch.tensor([0, 2, 1], dtype=torch.int64),
5593                torch.tensor([0, 2, 1], dtype=torch.int64),
5594            ),
5595        )
5596
5597    def test_output_strides(self):
5598        def fn(x):
5599            y = x.permute(0, 2, 3, 1).contiguous()
5600            torch._dynamo.graph_break()
5601            return y.view(-1, 4)
5602
5603        inp = torch.rand([4, 4, 4, 4], device=self.device)
5604        fn_opt = torch._dynamo.optimize("inductor")(fn)
5605
5606        self.assertEqual(fn(inp), fn_opt(inp))
5607        self.assertEqual(fn(inp).stride(), fn_opt(inp).stride())
5608
5609        # no redundant copy
5610        def foo(x):
5611            return x[0:2:2].T[3:].squeeze(0)
5612
5613        foo_opt = torch._dynamo.optimize("inductor")(foo)
5614        out = foo_opt(inp)
5615        self.assertEqual(inp.storage(), out.storage())
5616
5617    def test_index_select(self):
5618        def fn(a, b):
5619            return (
5620                torch.index_select(a, 0, b),
5621                torch.index_select(a, 1, b),
5622                torch.index_select(torch.index_select(a, 2, b), 1, b),
5623            )
5624
5625        for ind_dtype in (torch.int32, torch.int64):
5626            self.common(
5627                fn,
5628                (
5629                    torch.randn(8, 8, 8),
5630                    torch.tensor([0, 0, 2, 1], dtype=ind_dtype),
5631                ),
5632            )
5633
5634    @skipCUDAIf(not TEST_CUDNN, "CUDNN not available")
5635    @skipIfXpu
5636    @skipIfRocm
5637    def test_cudnn_rnn(self):
5638        if self.device == "cpu":
5639            raise unittest.SkipTest(f"requires {GPU_TYPE}")
5640
5641        def fn(
5642            a0,
5643            b0,
5644            b1,
5645            b2,
5646            b3,
5647            b4,
5648            b5,
5649            b6,
5650            b7,
5651            b8,
5652            b9,
5653            b10,
5654            b11,
5655            b12,
5656            b13,
5657            b14,
5658            b15,
5659            a3,
5660            a4,
5661            a5,
5662        ):
5663            a1 = [
5664                b0,
5665                b1,
5666                b2,
5667                b3,
5668                b4,
5669                b5,
5670                b6,
5671                b7,
5672                b8,
5673                b9,
5674                b10,
5675                b11,
5676                b12,
5677                b13,
5678                b14,
5679                b15,
5680            ]
5681            return aten._cudnn_rnn(
5682                a0,
5683                a1,
5684                4,
5685                a3,
5686                a4,
5687                a5,
5688                2,
5689                2048,
5690                0,
5691                2,
5692                False,
5693                0.0,
5694                False,
5695                True,
5696                [],
5697                None,
5698            )
5699
5700        self.common(
5701            fn,
5702            (
5703                torch.randn([92, 8, 2048]),
5704                torch.randn([8192, 2048]),
5705                torch.randn([8192, 2048]),
5706                torch.randn([8192]),
5707                torch.randn([8192]),
5708                torch.randn([8192, 2048]),
5709                torch.randn([8192, 2048]),
5710                torch.randn([8192]),
5711                torch.randn([8192]),
5712                torch.randn([8192, 4096]),
5713                torch.randn([8192, 2048]),
5714                torch.randn([8192]),
5715                torch.randn([8192]),
5716                torch.randn([8192, 4096]),
5717                torch.randn([8192, 2048]),
5718                torch.randn([8192]),
5719                torch.randn([8192]),
5720                torch.randn([167837696]),
5721                torch.randn([4, 8, 2048]),
5722                torch.randn([4, 8, 2048]),
5723            ),
5724            check_lowp=False,  # difference in rnn is too large between half and float inputs
5725        )
5726
5727    def test_upsample_nearest1d(self):
5728        def fn(a):
5729            return (
5730                aten.upsample_nearest1d(a, [74], None),
5731                aten.upsample_nearest1d(a, [70], None),
5732                aten.upsample_nearest1d(a, [45], None),
5733                aten.upsample_nearest1d(a, [36], None),
5734                aten.upsample_nearest1d(a, None, [2.0]),
5735            )
5736
5737        self.common(fn, (torch.randn([2, 4, 37]),))
5738
5739    def test_upsample_nearest2d(self):
5740        def fn(a):
5741            return (
5742                aten.upsample_nearest2d(a, [74, 76]),
5743                aten.upsample_nearest2d(a, [70, 75]),
5744                aten.upsample_nearest2d(a, [45, 74]),
5745                aten.upsample_nearest2d(a, [36, 39]),
5746                aten.upsample_nearest2d(a, None, [2.0, 2.0]),
5747            )
5748
5749        self.common(fn, (torch.randn([2, 4, 37, 38]),))
5750
5751    def test_upsample_nearest3d(self):
5752        def fn(a):
5753            return (
5754                aten.upsample_nearest3d(a, [74, 76, 78], None),
5755                aten.upsample_nearest3d(a, [70, 75, 80], None),
5756                aten.upsample_nearest3d(a, [45, 74, 103], None),
5757                aten.upsample_nearest3d(a, [36, 39, 40], None),
5758                aten.upsample_nearest3d(a, None, [2.0, 2.0, 2.0]),
5759            )
5760
5761        self.common(fn, (torch.randn([2, 4, 37, 38, 39]),))
5762
5763    def test_upsample_nearest2d_backward(self):
5764        func = torch.ops.aten.upsample_nearest2d_backward
5765
5766        def fn(a):
5767            return (
5768                func(a, output_size=[6, 12], input_size=[3, 3, 3, 6]),
5769                func(a, output_size=[6, 12], input_size=[3, 3, 4, 5]),
5770                func(a, output_size=[6, 12], input_size=[3, 3, 2, 8]),
5771                func(a, output_size=[6, 12], input_size=[3, 3, 2, 8]),
5772                func(a, output_size=[6, 12], input_size=[3, 3, 4, 7]),
5773            )
5774
5775        self.common(fn, (torch.randn([3, 3, 6, 12]),))
5776
5777    @skip_if_x86_mac()
5778    def test_upsample_bilinear2d_a(self):
5779        def fn(a):
5780            return (
5781                aten.upsample_bilinear2d(a, [45, 45], False, None),
5782                aten.upsample_bilinear2d(a, None, True, [2.0, 2.0]),
5783            )
5784
5785        self.common(fn, (torch.randn([2, 4, 37, 38]),), atol=2.5e-5, rtol=1.3e-6)
5786
5787    def test_upsample_bilinear2d_b(self):
5788        def fn(a):
5789            return aten.upsample_bilinear2d(a, None, True, [2.0, 2.0])
5790
5791        self.common(
5792            fn,
5793            [
5794                torch.randn([1, 2, 40, 59]),
5795            ],
5796            atol=2.5e-5,
5797            rtol=1.3e-6,
5798        )
5799
5800    def test_reflection_pad2d(self):
5801        def fn(a, pad):
5802            return (
5803                aten.reflection_pad2d(a, [1, 1, 1, 1]),
5804                aten.reflection_pad2d(a, pad),
5805            )
5806
5807        self.common(
5808            fn,
5809            (
5810                torch.randint(0, 999, size=[1, 1, 8, 8], dtype=torch.float32),
5811                [5, 2, 3, 4],
5812            ),
5813        )
5814
5815    def test_reflection_pad2d_backward(self):
5816        def template(size, padding):
5817            def fn(grad_output, x):
5818                return aten.reflection_pad2d_backward(grad_output, x, padding)
5819
5820            x = torch.randint(0, 999, size=size, dtype=torch.float32)
5821            result = aten.reflection_pad2d(x, padding)
5822            grad_output = torch.randn_like(result)
5823
5824            self.common(fn, (grad_output, x))
5825
5826        template([1, 1, 8, 8], [0, 0, 0, 0])
5827        template([1, 1, 8, 8], [1, 1, 1, 1])
5828        template([1, 1, 8, 8], [1, 2, 3, 4])
5829        template([1, 1, 8, 8], [0, -1, 2, 2])
5830        template([1, 1, 8, 8], [-1, 0, 2, 2])
5831        template([1, 1, 8, 8], [2, 2, 0, -1])
5832        template([1, 1, 8, 8], [2, 2, -1, 0])
5833
5834    def test_grid_sampler_2d(self):
5835        def fn(a, b):
5836            return (
5837                aten.grid_sampler_2d(a, b, 0, 0, True),
5838                aten.grid_sampler_2d(a, b, 0, 1, False),
5839            )
5840
5841        self.common(
5842            fn,
5843            (
5844                torch.randn([4, 3, 352, 352], dtype=torch.float32),
5845                torch.rand([4, 352, 352, 2], dtype=torch.float32) * 2 - 1,
5846            ),
5847            check_lowp=False,
5848            # Mismatched elements: 154697 / 1486848 (10.4%)
5849            # Greatest absolute difference: 0.0001976490020751953 at index (0, 0, 101, 243) (up to 1e-05 allowed)
5850            # Greatest relative difference: 7.332530120481928 at index (1, 1, 258, 301) (up to 1.3e-06 allowed)
5851            atol=0.0002,
5852            rtol=1.3e-06,
5853        )
5854
5855    def test_upsample_bicubic2d(self):
5856        def fn(a):
5857            return (
5858                aten.upsample_bicubic2d(a, (128, 128), True),
5859                aten.upsample_bicubic2d(a, (128, 256), False),
5860            )
5861
5862        # Mismatched elements: 10 / 196608 (0.0%)
5863        # Greatest absolute difference: 1.3869255781173706e-05 at index (2, 1, 88, 65) (up to 1e-05 allowed)
5864        # Greatest relative difference: 0.0033082996811011046 at index (3, 1, 88, 91) (up to 1.3e-06 allowed)
5865        self.common(
5866            fn,
5867            (torch.randn([4, 3, 64, 32], dtype=torch.float32),),
5868            atol=2e-5,
5869            rtol=1e-3,
5870        )
5871
5872    def test_float_index_expression(self):
5873        # Test that index propagation doesn't generate bad index_expr calls like
5874        # ops.index_expr(0.5*x, dtype) where the expression is not integral
5875        def fn(x):
5876            return aten.upsample_bicubic2d(x, (256, 256), False)
5877
5878        x = torch.randn(1, 1, 128, 128, dtype=torch.float32, device=self.device)
5879        _, source_codes = run_and_get_code(fn, x)
5880
5881        pattern = r"0\.50*\*[ix][\d]"
5882        for code in source_codes:
5883            self.assertIsNone(
5884                re.search(pattern, code), msg="Found bad index_expr in code:\n" + code
5885            )
5886
5887    def test_float_index_expression_type_promotion(self):
5888        # Test that float indexing expressions participate in type promotion
5889        def fn(x):
5890            return x + 1.0 / x.size(0)
5891
5892        x = torch.arange(10)
5893        self.common(fn, (x,))
5894
5895    def test_sort(self):
5896        def fn(a):
5897            return torch.sort(a)
5898
5899        self.common(
5900            fn, (torch.randint(0, 999, size=[1, 1, 8, 8], dtype=torch.float32),)
5901        )
5902
5903    def test_topk(self):
5904        def fn(a):
5905            return torch.topk(a, 2, -1)
5906
5907        self.common(
5908            fn, (torch.randint(0, 999, size=[1, 1, 8, 8], dtype=torch.float32),)
5909        )
5910
5911    def test_long_tensor(self):
5912        def fn(a):
5913            return (
5914                torch.LongTensor([294]).to(a.device) - a,
5915                torch.as_tensor([295]).to(a.device) + a,
5916            )
5917
5918        self.common(fn, (torch.randint(0, 999, size=[8, 8]),))
5919
5920    def test_constant_pad_1d(self):
5921        def fn(a):
5922            return (
5923                aten.constant_pad_nd(a, [0, 1], 6.0),
5924                aten.constant_pad_nd(a, [2, 3], 99.0),
5925            )
5926
5927        self.common(fn, (torch.randint(0, 999, size=[2, 16, 31], dtype=torch.float32),))
5928
5929    def test_constant_pad_fill_dtype(self):
5930        def fn(a, b):
5931            return (
5932                aten.constant_pad_nd(a, (1, 1), 1.0) & b,
5933                aten.constant_pad_nd(a, (1, 1), 0.0) & b,
5934            )
5935
5936        self.common(
5937            fn,
5938            (torch.randint(2, (4,), dtype=torch.bool), torch.ones(6, dtype=torch.bool)),
5939        )
5940
5941    def test_constant_pad_2d(self):
5942        def fn(a):
5943            return (
5944                aten.constant_pad_nd(a, [1, 1, 1, 1], 6.0),
5945                aten.constant_pad_nd(a, [1, 2, 3, 4], 99.0),
5946            )
5947
5948        self.common(
5949            fn, (torch.randint(0, 999, size=[1, 1, 8, 8], dtype=torch.float32),)
5950        )
5951
5952    def test_constant_pad_3d(self):
5953        def fn(a):
5954            return (
5955                aten.constant_pad_nd(a, [1, 2, 3, 4, 5, 6], 6.0),
5956                aten.constant_pad_nd(a, [0, 0, 3, 4, 0, 0], 6.0),
5957            )
5958
5959        self.common(
5960            fn, (torch.randint(0, 999, size=[2, 4, 4, 4], dtype=torch.float32),)
5961        )
5962
5963    def test_constant_pad_float64(self):
5964        # Repro for https://github.com/pytorch/pytorch/issues/93351
5965        def fn(input):
5966            v1 = torch.nn.functional.pad(input, pad=(1, 0))
5967            return torch.gt(v1, input)
5968
5969        x = torch.rand([1, 2, 2, 1], dtype=torch.float64)
5970        self.common(fn, (x,))
5971
5972    def test_constant_pad_nd_inplace(self):
5973        def fn(a):
5974            return aten.constant_pad_nd(a, [0, 0])
5975
5976        x = torch.randn([2], device=self.device)
5977        fn_compiled = torch.compile(fn)
5978        y = fn_compiled(x)
5979        self.assertTrue(y is not x)
5980
5981    def test_l1_loss(self):
5982        def fn(a, b):
5983            return torch.nn.functional.l1_loss(a, b), torch.nn.functional.mse_loss(a, b)
5984
5985        self.common(
5986            fn,
5987            (
5988                torch.randn([2, 3, 16, 16]),
5989                torch.randn([2, 3, 16, 16]),
5990            ),
5991            check_lowp=False,
5992        )
5993
5994    def test_triu(self):
5995        def fn(a):
5996            return aten.triu(a, 1), aten.triu(a, 0), aten.triu(a, 2)
5997
5998        self.common(fn, (torch.randn([2, 10, 10]),))
5999
6000    def test_no_op_reduction(self):
6001        def fn(a):
6002            return a.sum(-1), torch.amax(a + 1, 1, keepdim=True)
6003
6004        self.common(fn, (torch.randn([8, 1, 1]),))
6005
6006    def test_inplace_add(self):
6007        @torch._dynamo.optimize("inductor")
6008        def fn(x, y):
6009            return x.add_(y)
6010
6011        inputs = (
6012            rand_strided((4, 4), (4, 1), device=self.device),
6013            rand_strided((4, 4), (4, 1), device=self.device),
6014        )
6015        inp_clone = inputs[0].clone()
6016        out = fn(*inputs)
6017        self.assertTrue(same(out, inp_clone + inputs[1]))
6018        self.assertTrue(out is inputs[0])
6019
6020    # The following 2 tests are meant to check the logic that drops
6021    # xmask from triton load/store if xnumel = 1
6022    @requires_gpu()
6023    def test_single_elem(self):
6024        def fn(a):
6025            b = a + 1
6026            return (b,)
6027
6028        self.common(fn, (torch.randn(1),))
6029
6030    @requires_gpu()
6031    def test_single_elem_indirect(self):
6032        def fn(a, b):
6033            c = a[b] + 1
6034            return (c,)
6035
6036        a = torch.randn(1)
6037        b = (torch.tensor([0], dtype=torch.int64),)
6038
6039        self.common(fn, (a, b))
6040
6041    # This test is meant to check for issues from the logic
6042    # that drops xmask from trito load/store if XBLOCK divides xnumel
6043
6044    @requires_gpu()
6045    def test_xblock_divides_xnumel(self):
6046        def fn(a):
6047            b = a + 1
6048            return (b,)
6049
6050        # assumption is that XBLOCK is always a divisor of 1024
6051        # so xmask will be dropped iff xnumel is multiple of 1024
6052        self.common(fn, (torch.randn(1024),))
6053        self.common(fn, (torch.randn(1025),))
6054
6055    def test_inplace_mixed_dtype_ops(self):
6056        @torch._dynamo.optimize("inductor")
6057        def fn(x, y):
6058            z = x + y.float()
6059            w = z.add_(y)
6060            return w.mul_(y)
6061
6062        inputs = (
6063            rand_strided((4, 4), (4, 1), device=self.device, dtype=torch.float),
6064            rand_strided((4, 4), (4, 1), device=self.device, dtype=torch.double),
6065        )
6066        out = fn(*inputs)
6067        out_eager = (inputs[0] + inputs[1].float()).add_(inputs[1]).mul_(inputs[1])
6068        self.assertTrue(same(out, out_eager))
6069
6070    @config.patch(
6071        {"triton.unique_kernel_names": True, "triton.descriptive_names": False}
6072    )
6073    def test_kernel_names(self):
6074        @torch._dynamo.optimize("inductor")
6075        def fn(x):
6076            return 2 * x
6077
6078        inputs = (rand_strided((8,), (1,), device=self.device),)
6079        self.assertTrue(same(fn(*inputs), 2 * inputs[0]))
6080
6081    @config.patch({"triton.cudagraphs": True})
6082    @dynamo_config.patch(automatic_dynamic_shapes=True)
6083    def test_strided_inputs(self):
6084        @torch._dynamo.optimize("inductor")
6085        def fn(x, y):
6086            return x + y
6087
6088        inputs = (
6089            rand_strided((8, 16), (32, 2), device=self.device),
6090            rand_strided((8, 16), (16, 1), device=self.device),
6091        )
6092        self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))
6093
6094    @config.patch({"triton.cudagraphs": True})
6095    @dynamo_config.patch(automatic_dynamic_shapes=True)
6096    def test_input_mutation1(self):
6097        def fn(a):
6098            b = a + 1
6099            a.copy_(b)
6100            c = a + 2
6101            return a * b / c
6102
6103        arg1 = torch.randn(64, device=self.device)
6104        arg2 = arg1.clone()
6105        arg3 = torch.randn(64, device=self.device)
6106        arg4 = arg3.clone()
6107        correct1 = fn(arg1)
6108        correct2 = fn(arg3)
6109        opt_fn = torch._dynamo.optimize_assert(compile_fx)(fn)
6110        actual1 = opt_fn(arg2)
6111        actual2 = opt_fn(arg4)
6112
6113        self.assertTrue(same(actual1, correct1))
6114        self.assertTrue(same(actual2, correct2))
6115        self.assertTrue(same(arg1, arg2))
6116        self.assertTrue(same(arg3, arg4))
6117
6118    def test_input_mutation2(self):
6119        def fn(a):
6120            b = a + 1
6121            a.view(64).copy_(torch.tensor([66.0], device=a.device))
6122            c = a + 2
6123            return b, c
6124
6125        # NOTE: this test fails when none of the inputs require grad.
6126        # That seems like an inductor bug.
6127        arg1 = torch.randn([1, 64], device=self.device).requires_grad_(True).add(1)
6128        arg2 = arg1.clone()
6129        correct1 = fn(arg1)
6130        opt_fn = torch._dynamo.optimize_assert(compile_fx)(fn)
6131        actual1 = opt_fn(arg2)
6132
6133        self.assertTrue(same(actual1, correct1))
6134        self.assertTrue(same(arg1, arg2))
6135
6136    def test_input_mutation3(self):
6137        def fn(a):
6138            a += 1
6139            a *= 2
6140            aten.sigmoid_(a)
6141            a = a.view(64)
6142            a += 3
6143            a *= 4
6144            aten.relu_(a)
6145            return a
6146
6147        arg1 = torch.randn([1, 64], device=self.device)
6148        arg2 = arg1.clone()
6149        correct1 = fn(arg1)
6150        opt_fn = torch._dynamo.optimize_assert(compile_fx)(fn)
6151        actual1 = opt_fn(arg2)
6152
6153        self.assertTrue(same(actual1, correct1))
6154        self.assertTrue(same(arg1, arg2))
6155
6156    def test_input_mutation4(self):
6157        def fn(a):
6158            torch.relu_(a)
6159            return a
6160
6161        arg1 = torch.randn([1, 64], device=self.device)
6162        arg2 = arg1.clone()
6163        correct1 = fn(arg1)
6164        opt_fn = torch._dynamo.optimize_assert(compile_fx)(fn)
6165        actual1 = opt_fn(arg2)
6166
6167        self.assertTrue(same(actual1, correct1))
6168        self.assertTrue(same(arg1, arg2))
6169
6170    def test_input_mutation5(self):
6171        def fn(x):
6172            tmp = x.ceil()
6173            x.add_(10)
6174            return tmp
6175
6176        opt_fn = torch._dynamo.optimize()(fn)
6177
6178        a = torch.zeros((), dtype=torch.int64, device=self.device)
6179        a_expect = a.clone()
6180        expect = fn(a_expect)
6181
6182        a_actual = a.clone()
6183        actual = opt_fn(a_actual)
6184
6185        self.assertEqual(a_expect, a_actual)
6186        self.assertEqual(expect, actual)
6187
6188    def test_slice_mutation1(self):
6189        def fn(a):
6190            x = torch.zeros_like(a)
6191            b = x + 1
6192            x[:, 3] = 3.0
6193            c = torch.clone(x)
6194            x[4, :] = 4.0
6195            d = x + 1
6196            return x, b, c, d
6197
6198        self.common(fn, (torch.randn([8, 8]),))
6199
6200    def test_slice_mutation2(self):
6201        def fn(a):
6202            a[:, 20:40] = a[:, 20:40] + 1
6203            a[:, 2:11] = a[:, 1:10] + 2
6204
6205        arg1 = torch.randn([1, 64], device=self.device)
6206        arg2 = arg1.clone()
6207        fn(arg1)
6208        opt_fn = torch._dynamo.optimize_assert(compile_fx)(fn)
6209        opt_fn(arg2)
6210        self.assertTrue(same(arg1, arg2))
6211
6212    def test_slice_mutation3(self):
6213        def fn(a):
6214            a[:2, :2].fill_(10)
6215
6216        opt_fn = torch._dynamo.optimize_assert(compile_fx)(fn)
6217
6218        x1 = torch.randn(8, 8, device=self.device)
6219        x2 = x1.clone()
6220        fn(x1)
6221        opt_fn(x2)
6222        self.assertEqual(x1, x2)
6223
6224    def test_tensor_index_slice(self):
6225        def fn(a):
6226            x = torch.tensor([1, 2], device=self.device)
6227            y = torch.tensor([2, 3], device=self.device)
6228            xx = torch.tensor([1, 2], device=self.device).view(1, 2)
6229            yy = torch.tensor([1, 2, 3], device=self.device).view(3, 1)
6230            return [
6231                a[x, y],
6232                a[:, x, y],
6233                a[:, x, y, :],
6234                a[x, :, y],
6235                a[:, x, :, y, :],
6236                a[xx, yy],
6237                a[:, xx, yy],
6238                a[xx, :, yy],
6239                a[xx, yy, :],
6240                a[:, xx, :, yy],
6241            ]
6242
6243        a = torch.arange(3 * 4 * 5 * 6 * 7, device=self.device).view(3, 4, 5, 6, 7)
6244        refs = fn(a)
6245        tests = torch.compile(fn)(a)
6246        for ref, test in zip(refs, tests):
6247            torch.testing.assert_close(ref, test)
6248
6249    @torch._dynamo.config.patch(cache_size_limit=10)
6250    def test_tensor_index_put_slice(self):
6251        def fn(a, version):
6252            x = torch.tensor([1, 2], device=self.device, dtype=torch.int32)
6253            y = torch.tensor([2, 3], device=self.device, dtype=torch.int32)
6254
6255            xx = torch.tensor([1, 2], device=self.device).view(1, 2)
6256            yy = torch.tensor([1, 2, 3], device=self.device).view(3, 1)
6257
6258            if version == 0:
6259                a[x, y] = torch.zeros_like(a[x, y])
6260            elif version == 1:
6261                a[:, x, y] = torch.zeros_like(a[:, x, y])
6262            elif version == 2:
6263                a[:, x, y, :] = torch.zeros_like(a[:, x, y, :])
6264            elif version == 3:
6265                a[x, :, y] = torch.zeros_like(a[x, :, y])
6266            elif version == 4:
6267                a[:, x, :, y, :] = torch.zeros_like(a[:, x, :, y, :])
6268            elif version == 5:
6269                a[xx, yy] = torch.zeros_like(a[xx, yy])
6270            elif version == 6:
6271                a[:, xx, yy] = torch.zeros_like(a[:, xx, yy])
6272            elif version == 7:
6273                a[xx, :, yy] = torch.zeros_like(a[xx, :, yy])
6274            elif version == 8:
6275                a[xx, yy, :] = torch.zeros_like(a[xx, yy, :])
6276            elif version == 9:
6277                a[:, xx, :, yy] = torch.zeros_like(a[:, xx, :, yy])
6278
6279            return a
6280
6281        a = torch.arange(3 * 4 * 5 * 6 * 7, device=self.device, dtype=torch.int32).view(
6282            3, 4, 5, 6, 7
6283        )
6284        for i in range(10):
6285            ref = fn(torch.clone(a), i)
6286            test = torch.compile(fn)(torch.clone(a), i)
6287            torch.testing.assert_close(ref, test)
6288
6289    def test_indirect_load_broadcast(self):
6290        def fn(in_ptr0, in_ptr1, in_ptr2):
6291            return torch.gather(in_ptr1, 0, in_ptr2) + in_ptr0
6292
6293        arg190 = rand_strided((32, 21), (1, 32), device=self.device, dtype=torch.int64)
6294        arg190.fill_(0)
6295        arg111 = rand_strided(
6296            (9521, 512), (512, 1), device=self.device, dtype=torch.float32
6297        )
6298        self.common(
6299            fn,
6300            (
6301                torch.randn(32, 1),
6302                arg111,
6303                arg190,
6304            ),
6305        )
6306
6307    def test_roi_align(self):
6308        if not has_torchvision_roi_align():
6309            raise unittest.SkipTest("requires torchvision")
6310
6311        def fn(a, b):
6312            return torch.ops.torchvision.roi_align(a, b, 0.25, 7, 7, 2, False)
6313
6314        self.common(fn, (torch.zeros([4, 256, 296, 304]), torch.zeros([2292, 5])))
6315
6316    def test_nll_loss_forward(self):
6317        def fn(a, b):
6318            return aten.nll_loss_forward(a, b, None, 1, -100)
6319
6320        labels = (
6321            torch.zeros([5], dtype=torch.int64),
6322            torch.tensor([-100, -100, 3, -100, -100], dtype=torch.int64),
6323        )
6324        inps = (torch.randn(5, 5), torch.randn(5, 5))
6325        for a, b in zip(inps, labels):
6326            self.common(
6327                fn,
6328                (a, b),
6329            )
6330
6331    @skipIfXpu
6332    def test_nll_loss_backward(self):
6333        def fn(a, b, c):
6334            return aten.nll_loss_backward(
6335                a, b, c, None, 1, -100, torch.tensor(1.0, device=self.device)
6336            )
6337
6338        labels = (
6339            torch.zeros([5], dtype=torch.int64),
6340            torch.tensor([-100, -100, 3, -100, -100], dtype=torch.int64),
6341        )
6342        inps = (torch.randn(5, 5), torch.randn(5, 5))
6343        grad_outs = (torch.randn(()), torch.randn(()))
6344        for a, b, c in zip(grad_outs, inps, labels):
6345            self.common(
6346                fn,
6347                (a, b, c),
6348            )
6349
6350    def test_isinf(self):
6351        def fn(x):
6352            return x.isinf(), x.isnan()
6353
6354        self.common(
6355            fn, [torch.tensor([1, float("inf"), 2, float("-inf"), float("nan")])]
6356        )
6357        self.common(
6358            fn,
6359            [
6360                torch.tensor(
6361                    [1, float("inf"), 2, float("-inf"), float("nan")],
6362                    dtype=torch.float64,
6363                )
6364            ],
6365        )
6366
6367    def test_isinf2(self):
6368        def fn(x):
6369            y = torch.tensor(
6370                [1, float("inf"), 2, float("-inf"), float("nan")], device=self.device
6371            )
6372            return x == y
6373
6374        self.common(
6375            fn, (torch.tensor([1, float("inf"), 2, float("-inf"), float("nan")]),)
6376        )
6377
6378    def test_any(self):
6379        def fn(x):
6380            return (
6381                x.any(-1),
6382                x.isinf().any(),
6383                torch.all(x.isinf(), dim=0),
6384                torch.all(torch.logical_not(x.isinf())),
6385            )
6386
6387        self.common(fn, [-torch.rand(64)])
6388        tmp = torch.randn(16, 8)
6389        tmp[1, 1] = float("inf")
6390        self.common(fn, [tmp])
6391
6392    def test_multilayer_any(self):
6393        def fn(x):
6394            return (x.isinf().any(), x.isfinite().all())
6395
6396        sample = torch.rand(9, 3, 353, 353)
6397        self.common(fn, [sample])
6398
6399        sample.view(-1)[-1] = float("inf")
6400        self.common(fn, [sample])
6401
6402    def test_inplace_activations(self):
6403        def fn(x):
6404            a = aten.hardswish_(x + 1)
6405            b = aten.hardtanh_(x + 1)
6406            c = aten.leaky_relu_(x + 1)
6407            d = aten.silu_(x + 1)
6408            e = aten.log1p(x + 1)
6409            f = aten.masked_fill_(x + 1, torch.zeros_like(x, dtype=torch.bool), 99.0)
6410            h = aten.masked_fill_(x + 1, torch.ones_like(x, dtype=torch.bool), 99.0)
6411            return (a, b, c, d, e, f, h)
6412
6413        self.common(fn, [torch.randn(64) * 10])
6414
6415    def test_baddbmm(self):
6416        def fn(a, b, c, beta):
6417            return aten.baddbmm(a, b, c, beta=beta)
6418
6419        b = torch.randn(6, 128, 64)
6420        c = torch.randn(6, 64, 100)
6421        options = itertools.product(
6422            [torch.randn(6, 1, 100), torch.randn(6, 1, 100).fill_(torch.nan)],
6423            [0.0, 1.0],
6424        )
6425        for a, beta in options:
6426            self.common(
6427                fn,
6428                [a, b, c, beta],
6429                # Mismatched elements: 1212 / 76800 (1.6%)
6430                # Greatest absolute difference: 0.001953125 at index (0, 0, 93) (up to 1e-05 allowed)
6431                # Greatest relative difference: 1.0 at index (3, 19, 4) (up to 0.001 allowed)
6432                atol=0.002,
6433                rtol=0.001,
6434            )
6435
6436    @config.patch({"triton.max_tiles": 2})
6437    def test_fuse_tiled(self):
6438        def fn(a, b, c):
6439            return a + b, c + 1
6440
6441        self.common(
6442            fn, [torch.randn(128, 1), torch.randn(1, 128), torch.randn(128, 128)]
6443        )
6444
6445    def test_expand_as(self):
6446        def fn(a, b):
6447            return aten.expand_as(a, b), aten.expand_as(a + 1, b + 1) + 1
6448
6449        self.common(
6450            fn,
6451            [
6452                torch.randn(6, 1, 100),
6453                torch.randn(6, 128, 100),
6454            ],
6455        )
6456
6457    def test_index_put1(self):
6458        def fn(a, b, c):
6459            return (
6460                torch.index_put(a, [b], c),
6461                torch.index_put_(a + 1, [b + 1], c + 1) + 1,
6462            )
6463
6464        self.common(
6465            fn,
6466            [
6467                torch.randn([800, 256, 7, 7]),
6468                torch.randperm(601),
6469                torch.randn([601, 256, 7, 7]),
6470            ],
6471        )
6472        self.common(
6473            fn, [torch.randn(1024, 4, 2), torch.arange(4), torch.randn(4, 1, 1)]
6474        )
6475
6476    def test_index_put2(self):
6477        def fn(a, b, c):
6478            return torch.index_put(a, [b], c, True)
6479
6480        self.common(
6481            fn,
6482            [
6483                torch.randn([100, 256, 7, 7]),
6484                torch.randint(0, 100, size=[600], dtype=torch.int64),
6485                torch.randn([600, 256, 7, 7]),
6486            ],
6487            # workaround for https://github.com/openai/triton/issues/558
6488            check_lowp=False,
6489        )
6490
6491    def test_index_put3(self):
6492        def fn(a, b, c):
6493            torch.ops.aten.index_put_(a, (None, b, None), c)
6494            a1 = a + 1
6495            torch.ops.aten.index_put_(a1, (None, b + 1, None), c + 1)
6496            return (a, a1)
6497
6498        self.common(
6499            fn,
6500            [
6501                torch.randn([1024, 4, 2]),
6502                torch.arange(3),
6503                torch.randn([1024, 1, 2]),
6504            ],
6505        )
6506
6507    def test_index_put4(self):
6508        # a, b[0] are not broadcastable
6509        # https://github.com/pytorch/pytorch/issues/97104
6510        def fn(a, b, c):
6511            return torch.index_put(a, [b], c)
6512
6513        self.common(
6514            fn,
6515            [
6516                torch.rand([8, 2]),
6517                torch.rand([8]) > 0.5,
6518                torch.rand([]),
6519            ],
6520        )
6521
6522    def test_index_put_as_masked_fill(self):
6523        def fn(a, b, c, d):
6524            a = a.clone()
6525            torch.ops.aten.index_put_(a, [b], c, d)
6526            return a
6527
6528        self.common(
6529            fn,
6530            (
6531                torch.randn([1024, 4, 2]),
6532                torch.randn([1024, 4, 2]) > 0,
6533                torch.randn([]),
6534                False,
6535            ),
6536        )
6537
6538        self.common(
6539            fn,
6540            (
6541                torch.randn([1024, 4, 2]),
6542                torch.randn([1024, 4, 2]) > 0,
6543                torch.randn([]),
6544                True,
6545            ),
6546        )
6547
6548    def test_index_put_fallback1(self):
6549        def fn(a, b, c, d):
6550            a = a.clone()
6551            torch.ops.aten.index_put_(a, [b], c, d)
6552            return a
6553
6554        self.common(
6555            fn,
6556            (
6557                torch.randn([3]),
6558                torch.as_tensor([True, True, False]),
6559                torch.randn([2]),
6560                False,
6561            ),
6562        )
6563
6564        self.common(
6565            fn,
6566            (
6567                torch.randn([3]),
6568                torch.as_tensor([True, True, False]),
6569                torch.randn([2]),
6570                True,
6571            ),
6572        )
6573
6574    def test_index_put_fallback2(self):
6575        def fn(a, b, c, d, e):
6576            a = a.clone()
6577            torch.ops.aten.index_put_(a, [None, b, c], d, e)
6578            return a
6579
6580        self.common(
6581            fn,
6582            (
6583                torch.randn([1, 2, 3]),
6584                torch.as_tensor([0, 1]),
6585                torch.as_tensor([True, True, False]),
6586                torch.randn([]),
6587                False,
6588            ),
6589        )
6590        self.common(
6591            fn,
6592            (
6593                torch.randn([1, 2, 3]),
6594                torch.as_tensor([0, 1]),
6595                torch.as_tensor([True, True, False]),
6596                torch.randn([]),
6597                True,
6598            ),
6599        )
6600
6601    def test_index_put_deterministic_fallback(self):
6602        with DeterministicGuard(True):
6603
6604            def fn(a, b, c):
6605                return torch.index_put(a, [b], c, True)
6606
6607            self.common(
6608                fn,
6609                [
6610                    torch.randn([100, 32]),
6611                    torch.randint(0, 100, size=[600], dtype=torch.int64),
6612                    torch.randn([600, 32]),
6613                ],
6614                check_lowp=False,
6615            )
6616
6617    def test_index_put_index(self):
6618        def fn(ind, x, src):
6619            y = torch.ops.aten.index_put.default(x, [ind], src)
6620            return torch.ops.aten.index.Tensor(y, [ind])
6621
6622        args = [torch.tensor([1], dtype=torch.int64), torch.randn(8, 4), torch.randn(4)]
6623        self.common(fn, args)
6624
6625    def test_index_put_reinplace(self):
6626        def fn(x, idx):
6627            src = torch.ones(idx.size(0), device=x.device)
6628            x.index_put_((idx,), src)
6629            return x.expand((2, x.shape[0]))
6630
6631        a = torch.randn(1024)
6632        idx = torch.arange(10)
6633        torch._inductor.metrics.generated_kernel_count = 0
6634        self.common(fn, (a, idx))
6635        assertGeneratedKernelCountEqual(self, 1)
6636
6637    def test_index_put_failed_reinplace(self):
6638        def fn(x, idx):
6639            src = torch.ones(idx.size(0), device=x.device)
6640            y = x.index_put((idx,), src)
6641            return x, y
6642
6643        a = torch.randn(1024)
6644        idx = torch.arange(10)
6645        torch._inductor.metrics.generated_kernel_count = 0
6646        self.common(fn, (a, idx))
6647        assertGeneratedKernelCountEqual(self, 2)
6648
6649    def test_adding_tensor_offsets(self):
6650        @torch.compile(fullgraph=True)
6651        def fn(x):
6652            return x[16:32]
6653
6654        with torch.no_grad():
6655            x = torch.randn(1024, device=self.device)
6656            self.assertEqual(fn(x[0:]), x[16:][:16])
6657            self.assertEqual(fn(x[128:]), x[128 + 16 :][:16])
6658
6659    # from GPT2ForSequenceClassification
6660    def test_index_tensor(self):
6661        def fn(x, y):
6662            ne = torch.ops.aten.ne.Scalar(x, 0)
6663            sum = torch.ops.aten.sum.dim_IntList(ne, [-1])
6664            sub = torch.ops.aten.sub.Tensor(sum, 1)
6665            iota = torch.ops.prims.iota.default(
6666                1,
6667                start=0,
6668                step=1,
6669                dtype=torch.int64,
6670                device=x.device,
6671                requires_grad=False,
6672            )
6673            return torch.ops.aten.index.Tensor(y, [iota, sub])
6674
6675        self.common(fn, [torch.randn(1, 1024), torch.randn(1, 1024, 2)])
6676
6677    @config.patch(fallback_random=True)
6678    def test_bernoulli1(self):
6679        def fn(a):
6680            b = torch.empty_like(a)
6681            return aten.bernoulli_(b), b
6682
6683        self.common(
6684            fn,
6685            [
6686                torch.randn([100]),
6687            ],
6688        )
6689
6690    def test_bernoulli2(self):
6691        def fn(a):
6692            return aten.bernoulli(a)
6693
6694        self.common(
6695            fn,
6696            [torch.tensor([1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0])],
6697        )
6698
6699    def test_narrow(self):
6700        def fn(x):
6701            return (
6702                aten.narrow(x, 1, 10, 16),
6703                aten.narrow(x + 2, 0, 10, 16) + 1,
6704                aten.narrow_copy(x, 1, 10, 16),
6705            )
6706
6707        self.common(fn, [torch.randn(64, 64)])
6708
6709    def test_new_cpp_build_logical(self):
6710        from torch._inductor.codecache import validate_new_cpp_commands
6711
6712        validate_new_cpp_commands()
6713
6714    def test_as_strided(self):
6715        def fn(x):
6716            return (
6717                aten.as_strided(x, (8, 8, 64), (8 * 64, 64, 1), 0),
6718                aten.as_strided(x + 1, (8, 8, 64), (8 * 64, 64, 1), 0) + 2,
6719            )
6720
6721        def fn_channels_last(x):
6722            return (
6723                aten.as_strided(
6724                    x, (8, 384, 2, 20, 12), (153600, 1, 61440, 384, 7680), 0
6725                ),
6726                aten.as_strided(
6727                    x + 1, (8, 384, 2, 20, 12), (153600, 1, 61440, 384, 7680), 0
6728                )
6729                + 2,
6730            )
6731
6732        self.common(fn, [torch.randn(64, 64)])
6733        self.common(
6734            fn_channels_last,
6735            [torch.randn(8, 384, 20, 20).to(memory_format=torch.channels_last)],
6736        )
6737
6738    def test_like_channels_last(self):
6739        def foo():
6740            randn = torch.randn((4, 3, 8, 8), device=self.device, dtype=torch.float32)
6741            xc = randn.contiguous(memory_format=torch.channels_last)
6742            clone = torch.zeros_like(xc, memory_format=torch.preserve_format)
6743            rand_like = torch.rand_like(randn)
6744            return (xc, clone, rand_like)
6745
6746        out = foo()
6747        out_comp = torch.compile()(foo)()
6748
6749        for t, t_comp in zip(out, out_comp):
6750            self.assertEqual(t.stride(), t_comp.stride())
6751
6752    def test_as_strided_scatter(self):
6753        def fn(a, b):
6754            return aten.as_strided_scatter(
6755                a * 8 + 10,
6756                b * 2 - 4,
6757                size=(a.shape[0], a.shape[1] // 2),
6758                stride=(a.shape[1], 2),
6759                storage_offset=0,
6760            )
6761
6762        self.common(fn, [torch.randn(10, 1024), torch.randn(10, 512)])
6763
6764    def test_select_scatter(self):
6765        def fn(x, a, b):
6766            return (
6767                aten.select_scatter(x, a, 1, 0),
6768                aten.select_scatter(x, b, 0, 1),
6769            )
6770
6771        self.common(
6772            fn,
6773            [
6774                torch.randn(8, 197, 38),
6775                torch.randn(8, 38),
6776                torch.randn(197, 38),
6777            ],
6778        )
6779
6780    def test_slice_scatter(self):
6781        def fn(x, a):
6782            return (
6783                aten.slice_scatter(x, a, 2, 10, -10),
6784                aten.slice_scatter(x, a[:, :, :40], 2, 10, -10, 2),
6785            )
6786
6787        self.common(
6788            fn,
6789            [
6790                torch.randn(4, 8, 100),
6791                torch.randn(4, 8, 80),
6792            ],
6793        )
6794
6795    def test_slice_scatter2(self):
6796        def fn(a, b):
6797            return aten.slice_scatter(a, b, 0, 0, 9223372036854775807)
6798
6799        self.common(
6800            fn,
6801            [
6802                torch.randn([8, 197, 384]),
6803                torch.randn([8, 197, 384]),
6804            ],
6805        )
6806
6807    def test_slice_scatter3(self):
6808        def fn(a, b):
6809            return aten.slice_scatter.default(a, b, 1, 1, 9223372036854775807, 2)
6810
6811        self.common(
6812            fn,
6813            [
6814                torch.randn([1, 4]),
6815                torch.randn([1, 2]),
6816            ],
6817        )
6818
6819    def test_slice_scatter4(self):
6820        def fn(a, b):
6821            return aten.slice_scatter.default(a, b, 1, 2, 9223372036854775807, 3)
6822
6823        self.common(
6824            fn,
6825            [
6826                torch.randn([1, 9]),
6827                torch.randn([1, 3]),
6828            ],
6829        )
6830
6831    def test_slice_scatter5(self):
6832        # empty slices that require clamping the start or end
6833        def fn(a, b):
6834            return (
6835                aten.slice_scatter.default(a, b, 0, 2, 0, 1),
6836                aten.slice_scatter.default(a, b, 0, a.shape[0], a.shape[0] + 10, 1),
6837                aten.slice_scatter.default(a, b, 0, -20, 0, 1),
6838                aten.slice_scatter.default(a, b, 0, -20, -16, 1),
6839            )
6840
6841        a = torch.arange(10, dtype=torch.float)
6842        b = torch.empty(0)
6843        self.common(fn, [a, b])
6844
6845    def test_slice_scatter_reinplace(self):
6846        class M(nn.Module):
6847            def __init__(self, device):
6848                super().__init__()
6849                self.linear1 = nn.Linear(64, 64, bias=False)
6850                self.cache_k = torch.zeros((56, 384, 8, 64), device=device)
6851
6852            def forward(self, x, start_pos):
6853                bsz, seqlen, _, _ = x.shape
6854                xk = self.linear1(x)
6855                with torch.no_grad():
6856                    self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
6857                keys = self.cache_k[:bsz, : start_pos + seqlen]
6858                scores = torch.matmul(
6859                    xk.transpose(1, 2), keys.transpose(1, 2).transpose(2, 3)
6860                )
6861                return scores
6862
6863        kv_cache_module = M(self.device)
6864        inp = torch.randn(1, 32, 8, 64)
6865
6866        # Test that the cache update is reinplaced such that the cache is updated inplace
6867        # rather than copy-scatter-copy-back.
6868
6869        torch._inductor.metrics.generated_kernel_count = 0
6870        with torch.no_grad():
6871            self.common(kv_cache_module, (inp, 1), check_lowp=False)
6872        assertGeneratedKernelCountEqual(self, 1)
6873
6874    def test_scatter1(self):
6875        def fn(a, dim, index, b):
6876            return aten.scatter(a, dim, index, b)
6877
6878        self.common(
6879            fn,
6880            [
6881                torch.zeros(2, 3),
6882                -1,
6883                torch.tensor([[0]]),
6884                torch.ones(2, 3),
6885            ],
6886        )
6887
6888    def test_scatter2(self):
6889        if self.device == "cuda":
6890            raise unittest.SkipTest("unstable on sm86")
6891
6892        check_lowp = True
6893        if self.device == "xpu":
6894            check_lowp = False
6895
6896        def fn(a, dim, index, b):
6897            return aten.scatter.reduce(a, dim, index, b, reduce="add")
6898
6899        self.common(
6900            fn,
6901            [
6902                torch.zeros(64, 512),
6903                0,
6904                torch.zeros((64, 512), dtype=torch.int64),
6905                torch.ones(64, 512),
6906            ],
6907            check_lowp=check_lowp,
6908        )
6909
6910    def test_scatter3(self):
6911        def fn(a, dim, index, b):
6912            return aten.scatter(a, dim, index, b, reduce="add")
6913
6914        check_lowp = True
6915        if self.device == "xpu":
6916            check_lowp = False
6917
6918        self.common(
6919            fn,
6920            [
6921                torch.randn(5, 29, 13),
6922                2,
6923                torch.tensor([[[3, 5, 7, 9]]]),
6924                0.8,  # src can be a scalar
6925            ],
6926            # Mismatched elements: 1 / 1885 (0.1%)
6927            # Greatest absolute difference: 0.00018310546875 at index (0, 0, 3) (up to 1e-05 allowed)
6928            # Greatest relative difference: 0.0022371364653243847 at index (0, 0, 3) (up to 0.001 allowed)
6929            atol=2e-4,
6930            rtol=1e-3,
6931            check_lowp=check_lowp,
6932        )
6933
6934    def test_scatter4(self):
6935        def fn(x, ind, src):
6936            return torch.scatter(x, 0, ind, src)
6937
6938        check_lowp = True
6939        if self.device == "xpu":
6940            check_lowp = False
6941
6942        for deterministic in [False, True]:
6943            with DeterministicGuard(deterministic):
6944                self.common(
6945                    fn,
6946                    [
6947                        torch.randn(196, 992),
6948                        torch.randint(196, (1, 992)),
6949                        torch.randn(1, 992),
6950                    ],
6951                    check_lowp=check_lowp,
6952                )
6953
6954    def test_scatter5(self):
6955        def fn(a, dim, index, b, reduce):
6956            a = a.clone()
6957            a.scatter_(dim, index, b, reduce=reduce)
6958            a1 = a + 1.0
6959            a1.scatter_(dim, index, b, reduce=reduce)
6960            return (a, a1)
6961
6962        check_lowp = True
6963        if self.device == "xpu":
6964            check_lowp = False
6965
6966        for reduce in ["add", "multiply"]:
6967            self.common(
6968                fn,
6969                [
6970                    torch.ones((4, 5)),
6971                    0,
6972                    torch.tensor([[1], [2], [3]], dtype=torch.int64),
6973                    torch.randn(4, 5),
6974                    reduce,
6975                ],
6976                check_lowp=check_lowp,
6977            )
6978
6979    def test_scatter6(self):
6980        def fn(a, dim, index, b):
6981            return aten.scatter(a, dim, index, b)
6982
6983        check_lowp = True
6984        if self.device == "xpu":
6985            check_lowp = False
6986
6987        for deterministic in [False, True]:
6988            with DeterministicGuard(deterministic):
6989                self.common(
6990                    fn,
6991                    [
6992                        torch.randn(5, 8, 13),
6993                        2,
6994                        torch.tensor([[[3, 5, 7, 9]]]),
6995                        0.8,  # src can be a scalar
6996                    ],
6997                    check_lowp=check_lowp,
6998                )
6999
7000    @unittest.skip("Flaky test, needs debugging")
7001    def test_scatter_add1(self):
7002        def fn(a, dim, index, b):
7003            return aten.scatter_add(a, dim, index, b)
7004
7005        check_lowp = True
7006        if self.device == "xpu":
7007            check_lowp = False
7008
7009        self.common(
7010            fn,
7011            [
7012                torch.randn(2, 3),
7013                0,
7014                torch.tensor([[0]]),
7015                torch.randn(2, 3),
7016            ],
7017            check_lowp=check_lowp,
7018        )
7019
7020    def test_scatter_add2(self):
7021        def fn(a, dim, index, b):
7022            return aten.scatter_add(a, dim, index, b)
7023
7024        check_lowp = True
7025        if self.device == "xpu":
7026            check_lowp = False
7027
7028        self.common(
7029            fn,
7030            [
7031                torch.randn(2, 3),
7032                0,
7033                torch.tensor([[0, 0, 0], [1, 1, 1]]),
7034                torch.randn(2, 3),
7035            ],
7036            check_lowp=check_lowp,
7037        )
7038
7039    def test_scatter_add3(self):
7040        def fn(a, dim, index, b):
7041            return aten.scatter_add(a, dim, index, b)
7042
7043        check_lowp = True
7044        if self.device == "xpu":
7045            check_lowp = False
7046
7047        for deterministic in [False, True]:
7048            with DeterministicGuard(deterministic):
7049                self.common(
7050                    fn,
7051                    [
7052                        torch.randn(5, 29, 13),
7053                        2,
7054                        torch.tensor([[[3, 5, 7, 9]]]),
7055                        torch.randn(1, 1, 10),
7056                    ],
7057                    check_lowp=check_lowp,
7058                )
7059
7060    def test_scatter_reduce1(self):
7061        def fn(a, dim, index, b):
7062            return aten.scatter_reduce(a, dim, index, b, "sum")
7063
7064        check_lowp = True
7065        if self.device == "xpu":
7066            check_lowp = False
7067
7068        self.common(
7069            fn,
7070            [
7071                torch.randn(5, 29, 13),
7072                2,
7073                torch.tensor([[[3, 5, 7, 9]]]),
7074                torch.randn(1, 1, 10),
7075            ],
7076            check_lowp=check_lowp,
7077        )
7078
7079    def test_scatter_reduce2(self):
7080        def fn(a, dim, index, b, reduce):
7081            return aten.scatter_reduce(a, dim, index, b, reduce, include_self=False)
7082
7083        check_lowp = True
7084        if self.device == "xpu":
7085            check_lowp = False
7086
7087        for reduce in ["sum", "amax"]:
7088            self.common(
7089                fn,
7090                [
7091                    torch.randn(2, 3),
7092                    0,
7093                    torch.zeros((2, 3), dtype=torch.int64),
7094                    torch.randn(2, 3),
7095                    reduce,
7096                ],
7097                check_lowp=check_lowp,
7098            )
7099
7100    def test_scatter_reduce3(self):
7101        def fn(a, dim, index, b, reduce):
7102            a = a.clone()
7103            a.scatter_reduce_(dim, index, b, reduce=reduce)
7104            a1 = a + 1.0
7105            a1.scatter_reduce_(dim, index, b, reduce=reduce)
7106            return (a, a1)
7107
7108        check_lowp = True
7109        if self.device == "xpu":
7110            check_lowp = False
7111
7112        for reduce in ["sum", "prod"]:
7113            self.common(
7114                fn,
7115                [
7116                    torch.ones((4, 5)),
7117                    0,
7118                    torch.tensor([[1], [2], [3]], dtype=torch.int64),
7119                    torch.randn(4, 5),
7120                    reduce,
7121                ],
7122                check_lowp=check_lowp,
7123            )
7124
7125    def test_dense_mask_index(self):
7126        r"""
7127        There will be a little difference for reduce order between aten and inductor
7128        https://github.com/pytorch/pytorch/pull/122289
7129        Absolute difference: 0.00067138671875 (up to 1e-05 allowed)
7130        Relative difference: 3.1747371732500974e-06 (up to 1.3e-06 allowed)
7131        """
7132        kwargs = {}
7133        if self.device == "cpu":
7134            kwargs["atol"] = 1e-4
7135            kwargs["rtol"] = 1.3e-5
7136
7137        def fn(x, y):
7138            y = torch.ops.aten.select.int(y, 0, 2)
7139            z = x * y
7140            return z.sum()
7141
7142        self.common(fn, [torch.randn(102400), torch.randn(3)], **kwargs)
7143
7144    def test_empty1(self):
7145        def fn():
7146            return torch.empty((1, 128, 128))
7147
7148        self.common(fn, [], assert_equal=False)
7149
7150    def test_empty2(self):
7151        def fn():
7152            return aten.empty((1, 128, 128))
7153
7154        self.common(fn, [], assert_equal=False)
7155
7156    def test_new_empty(self):
7157        def fn(a):
7158            return aten.new_empty(a, [1, 128, 128])
7159
7160        self.common(fn, [torch.randn(55)], assert_equal=False)
7161
7162    def test_empty_strided(self):
7163        def fn():
7164            return aten.empty_strided([1, 128, 128], [16384, 128, 1])
7165
7166        self.common(fn, [], assert_equal=False)
7167
7168    def test_new_empty_strided(self):
7169        def fn(a):
7170            return aten.new_empty_strided(a, [1, 128, 128], [16384, 128, 1])
7171
7172        self.common(fn, [torch.randn(55)], assert_equal=False)
7173
7174    def test_dropout_trivial_0(self):
7175        def fn1(a):
7176            return torch.nn.functional.dropout(a, 0.0, True) + a
7177
7178        self.common(fn1, [torch.randn(55)])
7179
7180    def test_dropout_trivial_1(self):
7181        def fn2(a):
7182            return torch.nn.functional.dropout(a, 1.0, True) + a
7183
7184        self.common(fn2, [torch.randn(55)])
7185
7186    @config.patch({"triton.cudagraphs": True})
7187    @dynamo_config.patch(automatic_dynamic_shapes=True)
7188    def test_dropout(self):
7189        random.seed(1234)
7190        torch.manual_seed(1234)
7191
7192        @torch._dynamo.optimize("inductor")
7193        def fn1(a):
7194            return torch.nn.functional.dropout(a)
7195
7196        x = torch.ones(1000, device=self.device, dtype=torch.float32)
7197        result1 = fn1(x)
7198        self.assertTrue(400 < result1.nonzero().shape[0] < 600)
7199        self.assertTrue(0.9 < result1.mean().item() < 1.1)
7200
7201        random.seed(1234)
7202        torch.manual_seed(1234)
7203
7204        @torch._dynamo.optimize("inductor")
7205        def fn2(a):
7206            return torch.nn.functional.dropout(a, 0.5, True)
7207
7208        result2 = fn2(x)
7209        self.assertTrue(400 < result2.nonzero().shape[0] < 600)
7210        self.assertTrue(0.9 < result2.mean().item() < 1.1)
7211
7212    @dynamo_config.patch(automatic_dynamic_shapes=True)
7213    def test_dropout_deterministic(self):
7214        @torch._dynamo.optimize("inductor")
7215        def fn(a):
7216            return torch.nn.functional.dropout(a, 0.55, True)
7217
7218        for cg in [False, True]:
7219            with patch.object(config.triton, "cudagraphs", cg):
7220                torch._dynamo.reset()
7221
7222                x = torch.ones(1024, device=self.device, dtype=torch.float32)
7223
7224                torch.manual_seed(1234)
7225                a0 = fn(x).clone()
7226                a1 = fn(x).clone()
7227                a2 = fn(x).clone()
7228
7229                torch.manual_seed(1234)
7230                b0 = fn(x).clone()
7231                b1 = fn(x).clone()
7232                b2 = fn(x).clone()
7233
7234                # same seed, same values
7235                self.assertTrue(torch.allclose(a0, b0))
7236                self.assertTrue(torch.allclose(a1, b1))
7237                self.assertTrue(torch.allclose(a2, b2))
7238
7239                # different calls, different values
7240                self.assertFalse(torch.allclose(a0, a1))
7241                self.assertFalse(torch.allclose(a1, a2))
7242
7243    def test_rand_like_deterministic(self):
7244        @torch._dynamo.optimize("inductor")
7245        def fn(a):
7246            return torch.rand_like(a), torch.rand_like(a)
7247
7248        x = torch.ones(1024, device=self.device, dtype=torch.float32)
7249
7250        torch.manual_seed(1234)
7251        a0 = fn(x)[0].clone()
7252        a1 = fn(x)[0].clone()
7253        a2 = fn(x)[0].clone()
7254
7255        torch.manual_seed(1234)
7256        b0 = fn(x)[0].clone()
7257        b1 = fn(x)[0].clone()
7258        b2 = fn(x)[0].clone()
7259
7260        # same seed, same values
7261        self.assertTrue(torch.allclose(a0, b0))
7262        self.assertTrue(torch.allclose(a1, b1))
7263        self.assertTrue(torch.allclose(a2, b2))
7264
7265        # different calls, different values
7266        self.assertFalse(torch.allclose(a0, a1))
7267        self.assertFalse(torch.allclose(a1, a2))
7268
7269        c, d = fn(x)
7270        self.assertFalse(torch.allclose(c, d))
7271        self.assertTrue((c >= 0).all())
7272        self.assertTrue((c < 1).all())
7273        self.assertTrue((d >= 0).all())
7274        self.assertTrue((d < 1).all())
7275
7276    @config.patch(implicit_fallbacks=True)
7277    def test_fallback_mutable_op_basic(self):
7278        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
7279
7280            def impl(a, b, c, d, e=2):
7281                a.add_(b[0] * c * e),
7282                if d is not None:
7283                    d.add_(b[1])
7284
7285            m.define(
7286                "inplace_(Tensor(a!) a, Tensor[] b, SymInt c, *, Tensor(b!)? d, SymInt e=2) -> ()"
7287            )
7288            m.impl("inplace_", impl, "CompositeExplicitAutograd")
7289
7290            # We do some clones and copy_ to test that Inductor doesn't reorder
7291            # the copy_ w.r.t. inplace_.
7292            def f(a, b1, b2, c, d):
7293                a_ = a.clone()
7294                d_ = d if d is None else d.clone()
7295                torch.ops.mylib.inplace_(a_, (b1, b2), c, d=d_)
7296                a.copy_(a_)
7297                if d is not None:
7298                    d.copy_(d_)
7299                return ()
7300
7301            a = torch.tensor([0.0, 1.0, 2])
7302            b = [torch.tensor([2.0, 3.0, 5.0]), torch.tensor([1.0, 4.0, 6.0])]
7303            c = 4
7304            d = torch.tensor([2.0, 1, 0])
7305            args = (a, b[0], b[1], c, d)
7306            cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
7307            mod = make_fx(f)(*cloned_args)
7308            cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
7309            compiled_f = compile_fx_inner(mod, cloned_args)
7310
7311            cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
7312            compiled_f(list(cloned_args))
7313            f(*args)
7314            self.assertEqual(cloned_args, args)
7315
7316    @config.patch(implicit_fallbacks=True)
7317    def test_fallback_mutable_op_with_return(self):
7318        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
7319
7320            def impl(a, b, c, d, e=2):
7321                a.add_(b[0] * c * e),
7322                if d is not None:
7323                    d.add_(b[1])
7324                return b[0] + b[1]
7325
7326            m.define(
7327                "inplace_(Tensor(a!) a, Tensor[] b, SymInt c, *, Tensor(b!)? d, SymInt e=2) -> Tensor"
7328            )
7329            m.impl("inplace_", impl, "CompositeExplicitAutograd")
7330
7331            # We do some clones and copy_ to test that Inductor doesn't reorder
7332            # the copy_ w.r.t. inplace_.
7333            def f(a, b0, b1, c, d):
7334                a_ = a.clone()
7335                d_ = d if d is None else d.clone()
7336                res = torch.ops.mylib.inplace_(a_, (b0, b1), c, d=d_)
7337                a.copy_(a_)
7338                if d is not None:
7339                    d.copy_(d_)
7340                return (res,)
7341
7342            a = torch.tensor([0.0, 1.0, 2])
7343            b = [torch.tensor([2.0, 3.0, 5.0]), torch.tensor([1.0, 4.0, 6.0])]
7344            c = 4
7345            d = torch.tensor([2.0, 1, 0])
7346            args = (a, b[0], b[1], c, d)
7347
7348            cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
7349            mod = make_fx(f)(*cloned_args)
7350            cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
7351            compiled_f = compile_fx_inner(mod, cloned_args)
7352
7353            cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
7354            compiled_out = compiled_f(list(cloned_args))
7355            out = f(*args)
7356            self.assertEqual(cloned_args, args)
7357            self.assertEqual(compiled_out, out)
7358
7359    @config.patch(implicit_fallbacks=True)
7360    def test_fallback_mutable_op_no_mutated_tensors(self):
7361        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
7362
7363            def impl(a, b):
7364                if b is not None:
7365                    b.add_(1)
7366
7367            m.define("inplace_(Tensor a, Tensor(b!)? b) -> ()")
7368            m.impl("inplace_", impl, "CompositeExplicitAutograd")
7369
7370            def f(a):
7371                torch.ops.mylib.inplace_(a, None)
7372                return ()
7373
7374            a = torch.tensor([0.0, 1.0, 2])
7375            args = (a,)
7376            cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
7377            mod = make_fx(f)(*cloned_args)
7378            cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
7379            compiled_f = compile_fx_inner(mod, cloned_args)
7380
7381            cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
7382            compiled_f(list(cloned_args))
7383            f(*args)
7384            self.assertEqual(cloned_args, args)
7385
7386    @config.patch(implicit_fallbacks=True)
7387    def test_fallback_mutable_op_list(self):
7388        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
7389
7390            def impl(a, b):
7391                for bi in b:
7392                    bi.add_(a)
7393
7394            m.define("inplace_(Tensor a, Tensor(a!)[] b) -> ()")
7395            m.impl("inplace_", impl, "CompositeExplicitAutograd")
7396
7397            def f(a, b):
7398                torch.ops.mylib.inplace_(a, b)
7399                return ()
7400
7401            a = torch.tensor([0.0, 1.0, 2])
7402            b = [torch.tensor([2.0, 3.0, 5.0]), torch.tensor([1.0, 4.0, 6.0])]
7403            args = (a, b)
7404            cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
7405            mod = make_fx(f)(*cloned_args)
7406            cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
7407
7408            with self.assertRaisesRegex(
7409                torch._inductor.exc.LoweringException,
7410                "NYI: Can't generate FallbackKernel",
7411            ):
7412                compiled_f = compile_fx_inner(mod, cloned_args)
7413
7414    @expectedFailureXPU
7415    def test_functionalize_rng_wrappers(self):
7416        # Ideally, we would like to use torch.compile for these operators. But
7417        # currently the plan is to introduce these operators at the partitioner
7418        # level, obviating the need to support them fully through the
7419        # torch.compile stack. To ensure that we have good enough debugging with
7420        # minifiers, we have ensure that they work with make_fx. This test uses
7421        # make_fx to do the testing. In future, we can move on torch.compile.
7422        def fn():
7423            rng_state1, a1 = torch._prims.rng_prims.run_and_save_rng_state(
7424                torch.ops.aten.rand.default,
7425                [4, 4],
7426                dtype=torch.float32,
7427                device=self.device,
7428            )
7429            rng_state2, a2 = torch._prims.rng_prims.run_and_save_rng_state(
7430                torch.ops.aten.rand.default,
7431                [4, 4],
7432                dtype=torch.float32,
7433                device=self.device,
7434            )
7435
7436            b1 = torch._prims.rng_prims.run_with_rng_state(
7437                rng_state1,
7438                torch.ops.aten.rand.default,
7439                [4, 4],
7440                dtype=torch.float32,
7441                device=self.device,
7442            )
7443            b2 = torch._prims.rng_prims.run_with_rng_state(
7444                rng_state2,
7445                torch.ops.aten.rand.default,
7446                [4, 4],
7447                dtype=torch.float32,
7448                device=self.device,
7449            )
7450
7451            return (a1, a2, b1, b2)
7452
7453        mod = make_fx(fn)()
7454        compiled_f = compile_fx_inner(mod, ())
7455        a1, a2, b1, b2 = compiled_f(())
7456        self.assertEqual(a1, b1)
7457        self.assertEqual(a2, b2)
7458
7459    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
7460    @expectedFailureXPU
7461    def test_philox_rand(self):
7462        if self.device == "cpu":
7463            raise unittest.SkipTest(
7464                f"functionalization of rng ops supported only on {GPU_TYPE}"
7465            )
7466
7467        @torch._dynamo.optimize("inductor")
7468        def fn(x):
7469            a = torch.rand_like(x) * x
7470            a = torch.rand_like(x) * a
7471            return a
7472
7473        def check(x):
7474            torch.manual_seed(123)
7475            a = fn(x)
7476
7477            torch.manual_seed(1234)
7478            b = fn(x)
7479
7480            torch.manual_seed(123)
7481            c = fn(x)
7482
7483            # same seed, same values
7484            self.assertTrue(torch.allclose(a, c))
7485
7486            # different calls, different values
7487            self.assertFalse(torch.allclose(a, b))
7488
7489        check(torch.ones(1024, device=self.device, dtype=torch.float32))
7490        # Need comment: should we add "_get_rng_state_offset" to common device interface?
7491        self.assertEqual(getattr(torch, self.device)._get_rng_state_offset(), 2048)
7492        # Check non-multiple of 4 numel
7493        check(torch.ones(3, device=self.device, dtype=torch.float32))
7494        self.assertEqual(getattr(torch, self.device)._get_rng_state_offset(), 8)
7495
7496    # Already on by default, just want to make sure
7497    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
7498    def test_reuse_buffers_with_aliasing(self):
7499        def f(x):
7500            z = x + 1
7501            z = torch.view_as_complex(z)
7502            a = torch.view_as_real(z)
7503            out = a + 1
7504            return out, torch.view_as_real(z + 1)
7505
7506        self.common(f, (torch.zeros((4, 2)),))
7507
7508        code = run_and_get_triton_code(torch.compile(f), torch.zeros((4, 2)))
7509        # Make sure that we haven't added complex support and made this test
7510        # invalid. If we've added complex support please update the test to use
7511        # a different set of view ops we don't lower
7512        self.assertTrue("aten.view_as_real" in code)
7513
7514        def f2(x):
7515            z = x + 1
7516            z = torch.view_as_complex(z)
7517            z = torch.view_as_real(z)
7518            z = torch.view_as_complex(z)
7519            a = torch.view_as_real(z)
7520            out = a + 1
7521            return out, torch.view_as_real(z + 1)
7522
7523        self.common(f, (torch.zeros((4, 2)),))
7524
7525    def test_randn_like_empty(self):
7526        class Model(torch.nn.Module):
7527            def __init__(
7528                self,
7529            ):
7530                super().__init__()
7531
7532            def forward(self, v1: torch.Tensor):
7533                vx = v1.min(dim=1).values
7534                v2 = torch.randn_like(vx)
7535                return v2
7536
7537        model = Model()
7538        x = torch.rand(10, 3, 0)
7539
7540        self.common(model, (x,))
7541
7542    def test_randint(self):
7543        @torch.compile(fullgraph=True)
7544        def fn(x):
7545            return (
7546                torch.randint(10, [1024], device=x.device),
7547                torch.randint(-4, 7, [1024], dtype=torch.int32, device=x.device),
7548                torch.randint_like(x, 2**50),
7549            )
7550
7551        torch.manual_seed(12345)
7552        a0, b0, c0 = fn(torch.zeros([40, 40], device=self.device))
7553        self.assertEqual(a0.shape, [1024])
7554        self.assertEqual(b0.shape, [1024])
7555        self.assertEqual(c0.shape, [40, 40])
7556        torch.manual_seed(12345)
7557        a1, b1, c1 = fn(torch.zeros([40, 40], device=self.device))
7558        self.assertEqual(a0, a1)
7559        self.assertEqual(b0, b1)
7560        self.assertEqual(c0, c1)
7561
7562        self.assertEqual(a0.min(), 0)
7563        self.assertEqual(a0.max(), 9)
7564
7565        self.assertEqual(b0.min(), -4)
7566        self.assertEqual(b0.max(), 6)
7567
7568        self.assertGreaterEqual(c0.min(), 0)
7569        self.assertGreater(c0.max(), 2**40)
7570        self.assertLess(c0.max(), 2**50)
7571
7572    @config.patch(fallback_random=True)
7573    def test_like_rands(self):
7574        def fn(x):
7575            return torch.rand_like(x), torch.randn_like(x)
7576
7577        self.common(fn, [torch.zeros([20, 20])])
7578
7579    def test_like_rands2(self):
7580        # rand_like with kwargs `device` of str type
7581        d = self.device
7582        assert isinstance(d, str)
7583
7584        @torch.compile
7585        def fn(x):
7586            return torch.rand_like(x, device=d)
7587
7588        x = torch.ones(10, device=self.device, dtype=torch.float32)
7589        a0 = fn(x).clone()
7590        a1 = fn(x).clone()
7591        self.assertFalse(torch.allclose(a0, a1))
7592
7593    @requires_gpu()
7594    def test_like_rands3(self):
7595        # rand_like with `device` which is different from `x.device`
7596        def test_like_rands_on_different_device(device1, device2):
7597            @torch.compile
7598            def fn(x, device):
7599                return torch.rand_like(x, device=device)
7600
7601            x = torch.ones(10, device=device1, dtype=torch.float32)
7602            return fn(x, device2).clone()
7603
7604        a0 = test_like_rands_on_different_device("cpu", GPU_TYPE)
7605        a1 = test_like_rands_on_different_device(GPU_TYPE, "cpu")
7606        self.assertTrue(a0.device.type == GPU_TYPE)
7607        self.assertTrue(a1.device.type == "cpu")
7608
7609    def test_max_pool2d_with_indices_backward(self):
7610        def fn(a, b, c):
7611            return aten.max_pool2d_with_indices_backward(
7612                a, b, [2, 2], [2, 2], [0, 0], [1, 1], False, c
7613            )
7614
7615        x = torch.randn([2, 4, 18, 14])
7616        result, indices = aten.max_pool2d_with_indices(
7617            x,
7618            [2, 2],
7619            [2, 2],
7620            [0, 0],
7621            [1, 1],
7622            False,
7623        )
7624
7625        self.common(
7626            fn,
7627            [
7628                torch.randn_like(result),
7629                x,
7630                indices,
7631            ],
7632        )
7633
7634    def test_max_pool2d_with_indices_backward2(self):
7635        def fn(a, b, c):
7636            return aten.max_pool2d_with_indices_backward(
7637                a, b, [3, 3], [2, 2], [1, 1], [1, 1], True, c
7638            )
7639
7640        x = torch.randn([2, 4, 40, 56])
7641        result, indices = aten.max_pool2d_with_indices(
7642            x,
7643            [3, 3],
7644            [2, 2],
7645            [1, 1],
7646            [1, 1],
7647            True,
7648        )
7649
7650        self.common(
7651            fn,
7652            [
7653                torch.randn_like(result),
7654                x,
7655                indices,
7656            ],
7657        )
7658
7659    # From https://github.com/pytorch/torchdynamo/issues/1200
7660    def test_max_pool2d_with_indices_backward3(self):
7661        def fn(a, b, c):
7662            return aten.max_pool2d_with_indices_backward(
7663                a, b, [1, 1], [2, 2], [0, 0], [1, 1], False, c
7664            )
7665
7666        x = torch.randn([32, 256, 37, 38])
7667        result, indices = aten.max_pool2d_with_indices(
7668            x,
7669            [1, 1],
7670            [2, 2],
7671            0,
7672            1,
7673            False,
7674        )
7675        self.common(
7676            fn,
7677            [
7678                torch.randn_like(result),
7679                x,
7680                indices,
7681            ],
7682        )
7683
7684    # From https://github.com/pytorch/torchdynamo/issues/1352
7685    def test_max_pool2d_with_indices_backward4(self):
7686        def fn(a, b, c):
7687            return aten.max_pool2d_with_indices_backward(
7688                a, b, [5, 5], [1, 1], [2, 2], [1, 1], False, c
7689            )
7690
7691        torch._inductor.metrics.generated_kernel_count = 0
7692        x = torch.randn([2, 64, 3, 4])
7693        result, indices = aten.max_pool2d_with_indices(
7694            x,
7695            [5, 5],
7696            [1, 1],
7697            2,
7698            1,
7699            False,
7700        )
7701        self.common(
7702            fn,
7703            [
7704                torch.randn_like(result),
7705                x,
7706                indices,
7707            ],
7708        )
7709        assertGeneratedKernelCountEqual(self, 1)
7710
7711    @expectedFailureXPU
7712    def test_max_pool2d_with_indices_backward5(self):
7713        # Window size is too big. Should fallback
7714        def fn(a, b, c):
7715            return aten.max_pool2d_with_indices_backward(
7716                a, b, [13, 13], [1, 1], [2, 2], [1, 1], False, c
7717            )
7718
7719        torch._inductor.metrics.generated_kernel_count = 0
7720        x = torch.randn([2, 64, 20, 20])
7721        result, indices = aten.max_pool2d_with_indices(
7722            x,
7723            [13, 13],
7724            [1, 1],
7725            2,
7726            1,
7727            False,
7728        )
7729        self.common(
7730            fn,
7731            [
7732                torch.randn_like(result),
7733                x,
7734                indices,
7735            ],
7736        )
7737        assertGeneratedKernelCountEqual(self, 0)
7738
7739    # From https://github.com/pytorch/pytorch/issues/93384
7740    def test_max_pool2d_with_indices_backward6(self):
7741        # dilation is not 1. Should fallback
7742        def fn(a, b, c):
7743            return aten.max_pool2d_with_indices_backward(
7744                a, b, [3, 2], [2, 1], [1, 1], [1, 2], False, c
7745            )
7746
7747        torch._inductor.metrics.generated_kernel_count = 0
7748        x = torch.randn([2, 2, 3, 6])
7749        result, indices = aten.max_pool2d_with_indices(
7750            x,
7751            [3, 2],
7752            [2, 1],
7753            [1, 1],
7754            [1, 2],
7755            False,
7756        )
7757        self.common(
7758            fn,
7759            [
7760                torch.randn_like(result),
7761                x,
7762                indices,
7763            ],
7764        )
7765        assertGeneratedKernelCountEqual(self, 0)
7766
7767    def test_issue102546(self):
7768        def fn(x):
7769            return x.mean(0)
7770
7771        self.common(fn, [torch.rand(())])
7772
7773    def test_avg_pool2d_backward(self):
7774        def fn(a, b):
7775            return aten.avg_pool2d_backward(
7776                a,
7777                b,
7778                [2, 2],
7779                [2, 2],
7780                [0, 0],
7781                True,
7782                False,
7783                None,
7784            )
7785
7786        self.common(
7787            fn,
7788            [
7789                torch.randn([2, 4, 7, 7]),
7790                torch.randn([2, 4, 14, 14]),
7791            ],
7792        )
7793
7794    def test_avg_pool2d_backward2(self):
7795        def fn(a, b):
7796            return aten.avg_pool2d_backward(
7797                a,
7798                b,
7799                [3, 3],
7800                [1, 1],
7801                [1, 1],
7802                True,
7803                False,
7804                None,
7805            )
7806
7807        self.common(
7808            fn,
7809            [
7810                torch.randn([1, 1, 20, 15]),
7811                torch.randn([1, 1, 20, 15]),
7812            ],
7813        )
7814
7815    def test_avg_pool2d_backward3(self):
7816        def fn(a, b):
7817            return aten.avg_pool2d_backward(
7818                a,
7819                b,
7820                [1, 1],
7821                [2, 2],
7822                [0, 0],
7823                False,
7824                False,
7825                None,
7826            )
7827
7828        torch._inductor.metrics.generated_kernel_count = 0
7829        self.common(
7830            fn,
7831            [
7832                torch.randn([1, 2016, 11, 11]),
7833                torch.randn([1, 2016, 21, 21]),
7834            ],
7835        )
7836        assertGeneratedKernelCountEqual(self, 1)
7837
7838    def test_avg_pool2d_backward4(self):
7839        def fn(a, b):
7840            return aten.avg_pool2d_backward(
7841                a,
7842                b,
7843                [13, 13],
7844                [1, 1],
7845                [0, 0],
7846                True,
7847                False,
7848                None,
7849            )
7850
7851        torch._inductor.metrics.generated_kernel_count = 0
7852        self.common(
7853            fn,
7854            [
7855                torch.randn([1, 16, 12, 12]),
7856                torch.randn([1, 16, 24, 24]),
7857            ],
7858            check_lowp=False,
7859        )
7860        assertGeneratedKernelCountEqual(self, 0)
7861
7862    def test_avg_pool3d_backward(self):
7863        def fn(a, b):
7864            return aten.avg_pool3d_backward(
7865                a,
7866                b,
7867                [2, 2, 2],
7868                [2, 2, 2],
7869                [0, 0, 0],
7870                True,
7871                False,
7872                None,
7873            )
7874
7875        self.common(
7876            fn,
7877            [
7878                torch.randn([2, 4, 7, 7, 7]),
7879                torch.randn([2, 4, 14, 14, 14]),
7880            ],
7881        )
7882
7883    def test_avg_pool3d_backward2(self):
7884        def fn(a, b):
7885            return aten.avg_pool3d_backward(
7886                a,
7887                b,
7888                [3, 3, 3],
7889                [1, 1, 1],
7890                [1, 1, 1],
7891                True,
7892                False,
7893                None,
7894            )
7895
7896        self.common(
7897            fn,
7898            [
7899                torch.randn([1, 1, 20, 20, 15]),
7900                torch.randn([1, 1, 20, 20, 15]),
7901            ],
7902        )
7903
7904    def test_avg_pool3d_backward3(self):
7905        def fn(a, b):
7906            return aten.avg_pool3d_backward(
7907                a,
7908                b,
7909                [1, 1, 1],
7910                [2, 2, 2],
7911                [0, 0, 0],
7912                False,
7913                False,
7914                None,
7915            )
7916
7917        torch._inductor.metrics.generated_kernel_count = 0
7918        self.common(
7919            fn,
7920            [
7921                torch.randn([1, 2016, 11, 11, 11]),
7922                torch.randn([1, 2016, 21, 21, 21]),
7923            ],
7924        )
7925        assertGeneratedKernelCountEqual(self, 1)
7926
7927    def test_avg_pool3d_backward4(self):
7928        def fn(a, b):
7929            return aten.avg_pool3d_backward(
7930                a,
7931                b,
7932                [13, 13, 13],
7933                [1, 1, 1],
7934                [0, 0, 0],
7935                True,
7936                False,
7937                None,
7938            )
7939
7940        torch._inductor.metrics.generated_kernel_count = 0
7941        self.common(
7942            fn,
7943            [
7944                torch.randn([1, 16, 12, 12, 12]),
7945                torch.randn([1, 16, 24, 24, 24]),
7946            ],
7947            check_lowp=False,
7948        )
7949        assertGeneratedKernelCountEqual(self, 0)
7950
7951    @config.patch(search_autotune_cache=False)
7952    def test_mm_views(self):
7953        def fn(a, b):
7954            return torch.mm(a.view(32, 32), b.view(32, 32))
7955
7956        self.common(
7957            fn,
7958            (
7959                torch.randn([32, 32]).transpose(0, 1),
7960                torch.randn([1, 32, 32]).transpose(0, 1),
7961            ),
7962            check_lowp=False,
7963        )
7964        expected_kernel = 0
7965        # codegen mm kernel from template
7966        self.assertEqual(
7967            torch._inductor.metrics.generated_kernel_count, expected_kernel
7968        )
7969
7970    @torch._dynamo.config.patch(assume_static_by_default=False)
7971    def test_dtype_sympy_expr(self):
7972        @torch._dynamo.optimize_assert("inductor")
7973        def fn(a):
7974            y = a[..., :-1, :].contiguous()
7975            return y
7976
7977        result = fn(torch.randn([1, 2, 16, 4]).requires_grad_())
7978        result.sum().backward()
7979
7980    def test_dropout2(self):
7981        n = 100000
7982        weight = torch.ones(
7983            n, device=self.device, dtype=torch.float32, requires_grad=True
7984        )
7985        ones = torch.ones(n, device=self.device, dtype=torch.float32)
7986
7987        @torch._dynamo.optimize_assert("inductor")
7988        def run(x, train=True):
7989            return F.dropout(x * weight, 0.33, train)
7990
7991        def check(r, g):
7992            rmean = r.mean().item()
7993            gmean = g.mean().item()
7994            rcount = len(r.nonzero())
7995            gcount = len(g.nonzero())
7996
7997            # dropped elements should match
7998            self.assertTrue(same(r.nonzero(), g.nonzero()))
7999            self.assertEqual(rcount, gcount)
8000
8001            # dropped should be close to 0.33
8002            self.assertGreater(rcount, 0.64 * n)
8003            self.assertGreater(0.68 * n, rcount)
8004
8005            self.assertAlmostEqual(rmean, gmean)
8006            self.assertAlmostEqual(rmean, 1.0, places=2)
8007
8008        r1 = run(ones, train=False)
8009        r1.sum().backward()
8010        g1 = weight.grad.clone()
8011        # eval mode should be all ones
8012        self.assertTrue(same(r1, torch.ones_like(r1)))
8013        self.assertTrue(same(g1, torch.ones_like(g1)))
8014
8015        torch.manual_seed(1234)
8016        weight.grad.zero_()
8017        r2, (fw_code, bw_code) = run_fw_bw_and_get_code(lambda: run(ones))
8018        if self.device == GPU_TYPE:
8019            self.assertEqual(fw_code.count("tl.rand"), 1)
8020            self.assertEqual(bw_code.count("tl.rand"), 0)
8021        g2 = weight.grad.clone()
8022        check(r2, g2)
8023
8024        torch.manual_seed(1234)
8025        weight.grad.zero_()
8026        r3 = run(ones)
8027        r3.sum().backward()
8028        g3 = weight.grad.clone()
8029        check(r3, g3)
8030
8031        # second run is same result as first
8032        self.assertTrue(same(r2, r3))
8033        self.assertTrue(same(g2, g3))
8034
8035    @config.patch(search_autotune_cache=False)
8036    def test_dropout3(self):
8037        m = torch.nn.Sequential(
8038            torch.nn.Linear(32, 32, bias=False),
8039            torch.nn.Dropout(),
8040            torch.nn.Linear(32, 32, bias=False),
8041            torch.nn.Dropout(),
8042        ).to(self.device)
8043
8044        @torch._dynamo.optimize_assert("inductor")
8045        def run(x):
8046            return m(x)
8047
8048        torch._inductor.metrics.generated_kernel_count = 0
8049
8050        result, (fw_code, bw_code) = run_fw_bw_and_get_code(
8051            lambda: run(torch.randn([8, 32], device=self.device))
8052        )
8053
8054        if self.device == GPU_TYPE:
8055            self.assertEqual(fw_code.count("tl.rand"), 2)
8056            self.assertEqual(bw_code.count("tl.rand"), 0)
8057        expected_kernel = 4
8058
8059        self.assertEqual(
8060            torch._inductor.metrics.generated_kernel_count, expected_kernel
8061        )
8062
8063    def test_randint_kernel_count(self):
8064        @torch._dynamo.optimize_assert("inductor")
8065        def fn1():
8066            random_tensor1 = torch.randint(10, [32], device=self.device)
8067            random_tensor2 = torch.randint(10, [32], device=self.device)
8068            random_tensor3 = torch.randint(10, [32], device=self.device)
8069            return random_tensor1, random_tensor2, random_tensor3
8070
8071        _, source_codes = run_and_get_code(fn1)
8072        if self.device == GPU_TYPE:
8073            self.assertEqual(len(source_codes), 1)
8074            self.assertEqual(source_codes[0].count("async_compile.triton"), 2)
8075
8076    def test_roll(self):
8077        def fn(a):
8078            return (
8079                aten.roll(a, [-3, 10], [1, 2]),
8080                aten.roll(a, [5]),
8081            )
8082
8083        self.common(
8084            fn,
8085            [
8086                torch.randn([2, 56, 56, 16]),
8087            ],
8088        )
8089
8090    def test_argmax_min_int32(self):
8091        # https://github.com/pytorch/pytorch/issues/94055
8092        def fn(a, b):
8093            c = a.argmax(3)
8094            return torch.min(b, c)
8095
8096        a = torch.rand(3, 4, 2, 1).int()
8097        b = torch.rand(2, 2, 1, 4, 1).int()
8098        self.common(fn, (a, b))
8099
8100    def test_argmax_argmin1(self):
8101        def fn(x):
8102            return (aten.argmax(x), aten.argmin(x))
8103
8104        self.common(
8105            fn,
8106            [
8107                torch.randn([8, 256, 256]),
8108            ],
8109        )
8110
8111    def test_argmax_argmin2(self):
8112        def fn(x):
8113            return (
8114                aten.argmax(x, 0),
8115                aten.argmin(x, 0),
8116                aten.argmax(x, 1),
8117                aten.argmin(x, 1),
8118            )
8119
8120        self.common(fn, (torch.randn([144, 144]),))
8121
8122    def test_argmax_argmin_with_duplicates(self):
8123        def fn(x):
8124            return (
8125                aten.argmax(x, 0),
8126                aten.argmin(x, 0),
8127                aten.argmax(x, 1),
8128                aten.argmin(x, 1),
8129            )
8130
8131        # Unrolled reduction
8132        t1 = torch.randint(2, size=(6, 6))
8133        self.common(fn, (t1,))
8134
8135        # Persistent reduction
8136        t1 = torch.randint(8, size=(32, 32))
8137        self.common(fn, (t1,))
8138
8139        # Non-persistent reduction
8140        t1 = torch.randint(8, size=(1028, 1028))
8141        self.common(fn, (t1,))
8142
8143    def test_argmax_argmin_with_nan(self):
8144        def fn(x):
8145            return (
8146                aten.argmax(x, 0),
8147                aten.argmin(x, 0),
8148                aten.argmax(x, 1),
8149                aten.argmin(x, 1),
8150            )
8151
8152        if self.device == "cpu":
8153            raise unittest.SkipTest("broken on CPU")
8154
8155        # Unrolled reduction
8156        t1 = torch.randn((6, 6))
8157        t1[:, 1] = float("nan")
8158        t1[:, 3] = float("nan")
8159        self.common(fn, (t1,))
8160
8161        # Persistent reduction
8162        t1 = torch.randn((32, 32))
8163        t1[:, 4] = float("nan")
8164        t1[:, 8] = float("nan")
8165        self.common(fn, (t1,))
8166
8167        # Non-persistent reduction
8168        t1 = torch.randn((1028, 1028))
8169        t1[:, 40] = float("nan")
8170        t1[:, 100] = float("nan")
8171        self.common(fn, (t1,))
8172
8173    def test_conv_backward(self):
8174        def fn(rank4_inps, rank3_inps, rank5_inps):
8175            out1 = aten.convolution_backward(
8176                *rank4_inps,
8177                [C],
8178                [1, 1],
8179                [0, 0],
8180                [1, 1],
8181                False,
8182                [0, 0],
8183                1,
8184                [True, True, True],
8185            )
8186            out2 = aten.convolution_backward(
8187                *rank4_inps,
8188                [C],
8189                [1, 1],
8190                [0, 0],
8191                [1, 1],
8192                False,
8193                [0, 0],
8194                1,
8195                [True, False, False],
8196            )
8197            out3 = aten.convolution_backward(
8198                *rank3_inps,
8199                [C],
8200                [1],
8201                [0],
8202                [1],
8203                False,
8204                [0],
8205                1,
8206                [True, True, True],
8207            )
8208            out4 = aten.convolution_backward(
8209                *rank5_inps,
8210                [C],
8211                [1, 1, 1],
8212                [0, 0, 0],
8213                [1, 1, 1],
8214                False,
8215                [0, 0, 0],
8216                1,
8217                [True, True, True],
8218            )
8219            return (out1, out2, out3, out4)
8220
8221        B = 3
8222        C = 4
8223        H = 5
8224        grad_out = torch.randn(B, C, H - 2, H - 2, H - 2)
8225        inp = torch.randn(B, C, H, H, H)
8226        weight = torch.randn(C, C, 3, 3, 3)
8227
8228        def shrink_rank(x, rank):
8229            res = x
8230            while res.dim() > rank:
8231                res = torch.select(res, -1, 0)
8232            return res.contiguous()
8233
8234        rank4_inps = [shrink_rank(x, 4) for x in [grad_out, inp, weight]]
8235        rank3_inps = [shrink_rank(x, 4) for x in [grad_out, inp, weight]]
8236        rank5_inps = [shrink_rank(x, 5) for x in [grad_out, inp, weight]]
8237
8238        with torch.backends.cudnn.flags(enabled=True, allow_tf32=False):
8239            self.common(
8240                fn,
8241                [rank4_inps, rank3_inps, rank5_inps],
8242            )
8243
8244    @unittest.skip(
8245        """
8246        FIXME: In the case of having equally max/min elements, our implementation returns
8247        the last index instead of the first one
8248        """
8249    )
8250    def test_argmax_argmin3(self):
8251        def fn(x):
8252            return (
8253                aten.argmax(x, 0),
8254                aten.argmin(x, 0),
8255                aten.argmax(x, -1),
8256                aten.argmin(x, -1),
8257            )
8258
8259        self.common(
8260            fn,
8261            [torch.randint(0, 5, [10, 10])],
8262        )
8263
8264    def test_vdd_clamp(self):
8265        def fn(x):
8266            return torch.clamp_min(x, 3)
8267
8268        self.common(
8269            fn,
8270            [
8271                torch.randn([16], requires_grad=True) * 10,
8272            ],
8273        )
8274
8275    def test_tmp_not_defined_issue1(self):
8276        def forward(
8277            primals_3,
8278            primals_4,
8279            add_tensor,
8280            convert_element_type_default,
8281            div_default,
8282            reciprocal_default,
8283        ):
8284            var_default = torch.ops.aten.var(
8285                convert_element_type_default, [2], correction=0
8286            )
8287            sub_tensor = torch.ops.aten.sub.Tensor(add_tensor, div_default)
8288            mul_tensor_1 = torch.ops.aten.mul.Tensor(sub_tensor, reciprocal_default)
8289            mul_tensor_2 = torch.ops.aten.mul.Tensor(mul_tensor_1, primals_3)
8290            add_tensor_2 = torch.ops.aten.add.Tensor(mul_tensor_2, primals_4)
8291            convert_element_type_default_1 = add_tensor_2.to(dtype=torch.float32)
8292            convert_element_type_default_2 = convert_element_type_default_1.to(
8293                dtype=torch.float32
8294            )
8295            var_default_1 = torch.ops.aten.var(
8296                convert_element_type_default_2, [2], correction=0
8297            )
8298            broadcast_in_dim_default_2 = var_default_1.reshape(1, 512, 1)
8299            sum_default_1 = convert_element_type_default_2.sum(2)
8300            add_tensor_3 = torch.ops.aten.add.Tensor(broadcast_in_dim_default_2, 1e-05)
8301            return (var_default, sum_default_1, add_tensor_3)
8302
8303        inps = [
8304            (torch.Size([1024]), torch.float32),
8305            (torch.Size([1024]), torch.float32),
8306            (torch.Size([1, 512, 1024]), torch.float32),
8307            (torch.Size([1, 512, 1024]), torch.float32),
8308            (torch.Size([1, 512, 1]), torch.float32),
8309            (torch.Size([1, 512, 1]), torch.float32),
8310        ]
8311        inps = [torch.randn(shape, dtype=dtype) for (shape, dtype) in inps]
8312        self.common(forward, inps, atol=1e-05, rtol=2e-05)
8313
8314    @unittest.skipIf(
8315        os.environ.get("BUILD_ENVIRONMENT", "").startswith("parallelnative"),
8316        "TODO: debug this with asan",
8317    )
8318    def test_tmp_not_defined_issue2(self):
8319        def forward(arg38_1, arg81_1, getitem_17, new_zeros_default_4):
8320            div_tensor_7 = torch.ops.aten.div.Tensor(getitem_17, arg81_1)
8321            mul_tensor_24 = torch.ops.aten.mul.Tensor(div_tensor_7, arg38_1)
8322            sum_default_7 = torch.ops.aten.sum.default(mul_tensor_24)
8323            return (new_zeros_default_4, sum_default_7)
8324
8325        dtype = torch.float32
8326        args = [
8327            ((1, 88, 40, 40), (140800, 1600, 40, 1), dtype),
8328            ((), (), dtype),
8329            ((1, 88, 40, 40), (140800, 1600, 40, 1), dtype),
8330            ((3,), (1,), dtype),
8331        ]
8332        args = [
8333            rand_strided(shape, stride, dtype).requires_grad_(True).add(1)
8334            for shape, stride, dtype in args
8335        ]
8336        self.common(forward, args)
8337
8338    @requires_gpu()
8339    def test_tmp_not_defined_issue3(self):
8340        from torch import device
8341
8342        def forward(
8343            self,
8344            primals_1: "f32[1001, 6]",
8345            primals_2: "f32[1001]",
8346            primals_3: "f32[1001, 64]",
8347            primals_4: "f32[4190]",
8348            primals_5: "f32[4190]",
8349            primals_6: "f32[1739, 4190]",
8350            primals_48: "f32[6144, 4191]",
8351        ):
8352            _tensor_constant0: "i64[4190]" = self._tensor_constant0
8353            lift_fresh_copy: "i64[4190]" = torch.ops.aten.lift_fresh_copy.default(
8354                _tensor_constant0
8355            )
8356
8357            index: "f32[6144, 4190]" = torch.ops.aten.index.Tensor(
8358                primals_48, [None, lift_fresh_copy]
8359            )
8360
8361            _tensor_constant1: "i64[6]" = self._tensor_constant1
8362            lift_fresh_copy_1: "i64[6]" = torch.ops.aten.lift_fresh_copy.default(
8363                _tensor_constant1
8364            )
8365            index_1: "f32[6144, 6]" = torch.ops.aten.index.Tensor(
8366                primals_48, [None, lift_fresh_copy_1]
8367            )
8368            primals_48 = lift_fresh_copy_1 = None
8369            permute: "f32[6, 1001]" = torch.ops.aten.permute.default(primals_1, [1, 0])
8370            addmm: "f32[6144, 1001]" = torch.ops.aten.addmm.default(
8371                primals_2, index_1, permute
8372            )
8373            amax: "f32[6144, 1]" = torch.ops.aten.amax.default(addmm, [-1], True)
8374            sub: "f32[6144, 1001]" = torch.ops.aten.sub.Tensor(addmm, amax)
8375            exp: "f32[6144, 1001]" = torch.ops.aten.exp.default(sub)
8376            sum_1: "f32[6144, 1]" = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
8377            div: "f32[6144, 1001]" = torch.ops.aten.div.Tensor(exp, sum_1)
8378
8379            full_default: "i32[6144, 1001]" = torch.ops.aten.full.default(
8380                [6144, 1001],
8381                1,
8382                dtype=torch.int32,
8383                layout=torch.strided,
8384                device=device(type=GPU_TYPE, index=0),
8385                pin_memory=False,
8386            )
8387
8388            iota: "i32[1001]" = torch.ops.prims.iota.default(
8389                1001,
8390                start=0,
8391                step=1,
8392                dtype=torch.int32,
8393                device=device(type=GPU_TYPE),
8394                requires_grad=False,
8395            )
8396
8397            mul: "i32[6144, 1001]" = torch.ops.aten.mul.Tensor(full_default, iota)
8398            iota_1: "i32[6144]" = torch.ops.prims.iota.default(
8399                6144,
8400                start=0,
8401                step=1001,
8402                dtype=torch.int32,
8403                device=device(type=GPU_TYPE, index=0),
8404                requires_grad=False,
8405            )
8406            view: "i32[6150144]" = torch.ops.aten.reshape.default(mul, [-1])
8407            view_1: "f32[6150144]" = torch.ops.aten.reshape.default(div, [-1])
8408            _embedding_bag = torch.ops.aten._embedding_bag.default(
8409                primals_3, view, iota_1, False, 0, False, view_1
8410            )
8411            getitem: "f32[6144, 64]" = _embedding_bag[0]
8412            getitem_1: "i32[6150144]" = _embedding_bag[1]
8413            getitem_2: "i32[6144]" = _embedding_bag[2]
8414            getitem_3: "i32[0]" = _embedding_bag[3]
8415            unsqueeze: "f32[6144, 1, 64]" = torch.ops.aten.unsqueeze.default(getitem, 1)
8416            var_mean = torch.ops.aten.var_mean.correction(
8417                index, [1], correction=0, keepdim=True
8418            )
8419            getitem_4: "f32[6144, 1]" = var_mean[0]
8420            getitem_5: "f32[6144, 1]" = var_mean[1]
8421            add: "f32[6144, 1]" = torch.ops.aten.add.Tensor(getitem_4, 1e-05)
8422            rsqrt: "f32[6144, 1]" = torch.ops.aten.rsqrt.default(add)
8423            sub_1: "f32[6144, 4190]" = torch.ops.aten.sub.Tensor(index, getitem_5)
8424            mul_1: "f32[6144, 4190]" = torch.ops.aten.mul.Tensor(sub_1, rsqrt)
8425            mul_2: "f32[6144, 4190]" = torch.ops.aten.mul.Tensor(mul_1, primals_4)
8426            add_1: "f32[6144, 4190]" = torch.ops.aten.add.Tensor(mul_2, primals_5)
8427            permute_1: "f32[4190, 1739]" = torch.ops.aten.permute.default(
8428                primals_6, [1, 0]
8429            )
8430
8431            return [
8432                index,
8433                index_1,
8434                addmm,
8435                amax,
8436                sum_1,
8437                iota_1,
8438                view,
8439                view_1,
8440                getitem_1,
8441                getitem_2,
8442                getitem_3,
8443                unsqueeze,
8444                getitem_5,
8445                rsqrt,
8446                add_1,
8447                permute_1,
8448            ]
8449
8450        kwargs = aot_graph_input_parser(forward, device=GPU_TYPE)
8451        self.common(forward, [], kwargs=kwargs)
8452
8453    def test_misaligned_address_issue1(self):
8454        def forward(sub_tensor_1, unsqueeze_default):
8455            gather_default = torch.ops.aten.gather.default(
8456                sub_tensor_1, 1, unsqueeze_default
8457            )
8458            return gather_default
8459
8460        args = [
8461            ((1, 1000), (1000, 1), torch.float32),
8462            ((1, 1), (1, 1), torch.int64),
8463        ]
8464        args = [rand_strided(shape, stride, dtype) for shape, stride, dtype in args]
8465        self.common(forward, args)
8466
8467    def test_invalid_operand_issue1(self):
8468        def forward(arg0_1, arg1_1, arg3_1, squeeze, view_1, slice_1):
8469            slice_scatter = torch.ops.aten.slice_scatter.default(
8470                slice_1, arg3_1, 1, 1, 9223372036854775807
8471            )
8472            slice_scatter_1 = torch.ops.aten.slice_scatter.default(
8473                arg1_1, slice_scatter, 0, 0, 9223372036854775807
8474            )
8475            slice_2 = torch.ops.aten.slice.Tensor(
8476                slice_scatter_1, 0, 0, 9223372036854775807
8477            )
8478            select_scatter = torch.ops.aten.select_scatter.default(
8479                slice_2, squeeze, 1, 0
8480            )
8481            slice_scatter_2 = torch.ops.aten.slice_scatter.default(
8482                slice_scatter_1, select_scatter, 0, 0, 9223372036854775807
8483            )
8484            view = torch.ops.aten.view.default(slice_scatter_2, [-1, 128])
8485            embedding = torch.ops.aten.embedding.default(arg0_1, view, 1)
8486            return [embedding, view_1]
8487
8488        args = [
8489            ((50005, 768), (768, 1), torch.float32),
8490            ((8, 128), (128, 1), torch.int64),
8491            ((8, 127), (127, 1), torch.int64),
8492            ((8,), (1,), torch.int64),
8493            ((1024,), (1,), torch.int64),
8494            ((8, 128), (128, 1), torch.int64),
8495        ]
8496        args = [rand_strided(shape, stride, dtype) for shape, stride, dtype in args]
8497        self.common(forward, args)
8498
8499    def test_sizehint_issue1(self):
8500        def forward(x):
8501            return torch.nn.functional.unfold(
8502                x, kernel_size=[4, 4], dilation=1, padding=0, stride=[4, 4]
8503            )
8504
8505        args = [((2, 24, 56, 56), (75264, 3136, 56, 1), torch.float32, False)]
8506        args = [
8507            rand_strided(sh, st, dt).requires_grad_(rg) for (sh, st, dt, rg) in args
8508        ]
8509        self.common(forward, args)
8510
8511    def test_zero_dim_reductions(self):
8512        for kd in [True, False]:
8513            inps0 = (torch.zeros(2, 0, device=self.device, dtype=torch.float16), 1, kd)
8514            failed_ops = [aten.argmin, aten.argmax, aten.max, aten.min]
8515            for fo in failed_ops:
8516                with self.assertRaisesRegex(
8517                    IndexError, "Expected reduction dim 1 to have non-zero size"
8518                ):
8519                    mod = make_fx(fo)(*inps0)
8520                    _ = compile_fx_inner(mod, inps0)
8521
8522            pass_ops = [
8523                lambda *x: fn(*x) for fn in [aten.sum, aten.prod, aten.any, aten.all]
8524            ]
8525            for po in pass_ops:
8526                compiled = torch._dynamo.optimize("inductor")(po)
8527                expected = po(*inps0)
8528                actual = compiled(*inps0)
8529
8530            self.assertTrue(torch.allclose(actual, expected, atol=1e-3, rtol=1e-3))
8531
8532    def test_unfold_zero_dimension_tensor(self):
8533        def forward(x):
8534            return torch.unfold_copy(dimension=1, input=x, size=0, step=7)
8535
8536        x = torch.rand([1, 0], dtype=torch.float32)
8537
8538        y = forward(x)
8539        compiled_y = torch.compile(forward, fullgraph=True)(x)
8540
8541        self.assertEqual(y, compiled_y)
8542
8543    def test_zero_element_mutation(self):
8544        class CustomModel(nn.Module):
8545            def __init__(self):
8546                super().__init__()
8547                self.layer1 = nn.LeakyReLU(negative_slope=5.2955089, inplace=True)
8548
8549            def forward(self, inputs):
8550                return self.layer1(inputs)
8551
8552        ip_size = [0]
8553        input_tensor = torch.randn(ip_size)
8554
8555        mymodel = CustomModel()
8556        self.common(mymodel, (input_tensor,))
8557
8558    def test_lerp(self):
8559        # non-contiguous inputs for lerp
8560        def fn0(i0, i1):
8561            x1 = i0.transpose(-2, -3)
8562            return torch.lerp(i1, x1, 70000)
8563
8564        # contiguous inputs for lerp
8565        def fn1(i0, i1):
8566            return torch.lerp(i1, i0, 70000)
8567
8568        self.common(fn0, [torch.rand(10, 3, 10), torch.rand(3, 10, 10)])
8569        self.common(fn1, [torch.rand(3, 10, 10), torch.rand(3, 10, 10)])
8570
8571    def test_unspec_inputs(self):
8572        if self.device == "cpu":
8573            raise unittest.SkipTest("Testing mixed devices")
8574
8575        def fn(x, y):
8576            return x + y, x * y, x / y
8577
8578        opt = torch._dynamo.optimize("inductor")(fn)
8579        dtypes = [
8580            torch.float16,
8581            torch.bfloat16,
8582            torch.float32,
8583            torch.float64,
8584            torch.int32,
8585            torch.int64,
8586        ]
8587
8588        for d in dtypes:
8589            inputs = (
8590                rand_strided((2, 3), (3, 1), dtype=torch.float32, device=GPU_TYPE),
8591                rand_strided((), (), dtype=d, device="cpu"),
8592            )
8593            self.assertTrue(same(opt(*inputs), fn(*inputs)))
8594            inputs = (inputs[1], inputs[0])
8595            self.assertTrue(same(opt(*inputs), fn(*inputs)))
8596
8597    @dynamo_config.patch(automatic_dynamic_shapes=True)
8598    def test_list_clearing(self):
8599        if self.device == "cpu":
8600            contexts = [contextlib.nullcontext]
8601        else:
8602            contexts = [
8603                contextlib.nullcontext,
8604                lambda: config.patch({"triton.cudagraphs": True}),
8605            ]
8606
8607        for context in contexts:
8608            with context():
8609                inps = [
8610                    torch.rand([5, 5]).to(self.device),
8611                    torch.rand([5, 5]).to(self.device),
8612                ]
8613                inp_refs = [weakref.ref(inp) for inp in inps]
8614
8615                def fn(x, y):
8616                    a = x + y
8617                    return (a @ a,)
8618
8619                fn_fx = make_fx(fn)(inps[0], inps[1])
8620                fn_compiled = compile_fx_inner(fn_fx, inps)
8621
8622                test_self = self
8623                matmul_seen = False
8624
8625                class TestRefMode(TorchDispatchMode):
8626                    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
8627                        kwargs = kwargs if kwargs else {}
8628
8629                        nonlocal inps
8630                        nonlocal inp_refs
8631                        nonlocal test_self
8632                        nonlocal matmul_seen
8633
8634                        # by matmul, inputs should be deallocated
8635                        # TODO: should not be necessary, ref-cycle ?
8636                        gc.collect()
8637                        if func is aten.mm.out:
8638                            matmul_seen = True
8639                            test_self.assertEqual(len(inps), 0)
8640                            test_self.assertIsNone(inp_refs[0]())
8641                            test_self.assertIsNone(inp_refs[1]())
8642
8643                        return func(*args, **kwargs)
8644
8645                with TestRefMode():
8646                    fn_compiled(inps)
8647
8648                # do an extra run to make sure we are deallocating on warmup and record
8649                if self.device == GPU_TYPE:
8650                    inps.extend(
8651                        [
8652                            torch.rand([5, 5]).to(self.device),
8653                            torch.rand([5, 5]).to(self.device),
8654                        ]
8655                    )
8656                    inp_refs.extend([weakref.ref(inp) for inp in inps])
8657                    matmul_seen = False
8658
8659                    with TestRefMode():
8660                        fn_compiled(inps)
8661
8662                # for some reason, TorchDispatch doesnt capture the
8663                # cuda mm call (even without cudagraphs)
8664                if self.device == "cpu":
8665                    self.assertTrue(matmul_seen)
8666                else:
8667                    self.assertEqual(len(inps), 0)
8668
8669    def test_dtype_mismatch_issue(self):
8670        def fn(x):
8671            attn = torch.nn.functional.pad(x, [0, 1])
8672            return attn.softmax(dim=-1)
8673
8674        x = torch.rand(128, 32, 63)
8675        self.common(fn, (x,))
8676
8677    def test_diagonal_copy(self):
8678        def fn(x):
8679            return torch.diagonal_copy(x)
8680
8681        for x in (torch.randn(2, 3), torch.randn(2, 2), torch.randn(3, 2)):
8682            self.common(fn, (x,))
8683
8684    def test_kwargs(self):
8685        if self.device == GPU_TYPE:
8686            raise unittest.SkipTest("histogramdd only supports cpu")
8687
8688        def fn(x, y):
8689            return torch.histogramdd(
8690                x,
8691                bins=[3, 3],
8692                weight=y,
8693            )
8694
8695        self.common(
8696            fn,
8697            [torch.randn((4, 2)), torch.randn(4)],
8698        )
8699
8700    # Shape padding causes the inputs to all get specialized, so the codegen
8701    # test fails
8702    @expectedFailureCodegenDynamic
8703    @requires_gpu()
8704    @torch._inductor.config.patch("shape_padding", True)
8705    def test_shape_padding(self):
8706        dtypes = [
8707            torch.float16,
8708            torch.float32,
8709        ]
8710
8711        b, m, n, k = 7, 11, 13, 15
8712
8713        def gen(*shape, dtype=torch.float32):
8714            return torch.randn(*shape, device=GPU_TYPE, dtype=dtype) / k + 1.0
8715
8716        for dtype in dtypes:
8717            x = gen(m, k, dtype=dtype)
8718            y = gen(k, n, dtype=dtype)
8719            z = gen(n, dtype=dtype)
8720            self.common(lambda x, y: torch.mm(x, y), (x, y))
8721            self.common(lambda x, y: torch.matmul(x, y), (x, y))
8722            self.common(lambda x, y, z: torch.addmm(z, x, y), (x, y, z))
8723
8724        for dtype in dtypes:
8725            x = gen(b, m, k, dtype=dtype)
8726            y = gen(b, k, n, dtype=dtype)
8727            z = gen(n, dtype=dtype)
8728            self.common(lambda x, y: torch.bmm(x, y), (x, y))
8729            self.common(lambda x, y: torch.matmul(x, y), (x, y))
8730            self.common(lambda x, y, z: torch.baddbmm(z, x, y), (x, y, z))
8731
8732    @requires_gpu()
8733    @torch._inductor.config.patch("layout_optimization", True)
8734    def test_inductor_layout_optimization_input_mutations(self):
8735        # channel dim must be > 64 for inductor to do layout optimization and use NHWC
8736        mod = nn.Conv2d(3, 128, 1, stride=1, bias=False).to(GPU_TYPE)
8737
8738        def f(x):
8739            x.mul_(2)
8740            out = mod(x)
8741            return out
8742
8743        f_compiled = torch.compile(f)
8744        x_ref = torch.rand(2, 3, 128, 128, device=GPU_TYPE)
8745        x_test = x_ref.clone().detach()
8746        with torch.no_grad():
8747            out_ref = f(x_ref)
8748            out_test = f_compiled(x_test)
8749            self.assertEqual(out_ref, out_test)
8750            self.assertEqual(out_ref.shape, out_test.shape)
8751            # Importantly, since inductor._config.keep_output_stride is True,
8752            # the outputs should have matching strides here.
8753            self.assertEqual(out_ref.stride(), out_test.stride())
8754            self.assertEqual(x_ref, x_test)
8755
8756    def test_int_input_dynamic_shapes(self):
8757        @torch.compile(dynamic=True)
8758        def fn(x, i):
8759            y = x * i
8760            return y
8761
8762        # Constant must not get matched as constant
8763        self.common(fn, [torch.randn(3, 1, 1, 1, 1), 9132])
8764
8765    def test_sqrt_dynamic_shapes(self):
8766        # TIMM convit_base model: https://github.com/pytorch/pytorch/issues/97877.
8767        # TODO: support cuda path.
8768        if self.device == GPU_TYPE:
8769            raise unittest.SkipTest("sqrt dynamic shapes only supports cpu")
8770
8771        class Model(torch.nn.Module):
8772            def __init__(self):
8773                super().__init__()
8774
8775            def forward(self, x):
8776                B, N, C = x.shape
8777                return self.get_rel_indices(N)
8778
8779            def get_rel_indices(self, num_patches: int) -> torch.Tensor:
8780                img_size = int(num_patches**0.5)
8781                ind = torch.arange(img_size)
8782                return ind
8783
8784        self.common(
8785            Model(),
8786            [
8787                torch.randn(8, 4, 4),
8788            ],
8789        )
8790
8791    def test_rsqrt_dynamic_shapes(self):
8792        # From HF hf_BigBird model.
8793        @torch.compile(dynamic=True)
8794        def fn(a, b):
8795            r = 1 / math.sqrt(a.size(1))
8796            return torch.bmm(a, b) / r
8797
8798        self.common(
8799            fn,
8800            [
8801                torch.randn(2, 4, 4),
8802                torch.randn(2, 4, 4),
8803            ],
8804        )
8805
8806    def test_index_dynamic_shapes(self):
8807        # Repro from vision_maskrcnn
8808        def fn(arg0_1):
8809            unsqueeze = arg0_1.unsqueeze(0)
8810            sym_size = arg0_1.size(1)
8811            ceil = math.ceil(sym_size * 1.8735363483428955)
8812            iota = torch.ops.prims.iota.default(
8813                ceil,
8814                start=0,
8815                step=1,
8816                dtype=torch.int64,
8817                device=arg0_1.device,
8818                requires_grad=False,
8819            )
8820            convert_element_type_1 = iota.to(torch.float32)
8821            sym_size_1 = arg0_1.size(2)
8822            floor_1 = math.floor(sym_size_1 * 1.8735363483428955)
8823            ceil_1 = math.ceil(floor_1)
8824            iota_1 = torch.ops.prims.iota.default(
8825                ceil_1,
8826                start=0,
8827                step=1,
8828                dtype=torch.int64,
8829                device=arg0_1.device,
8830                requires_grad=False,
8831            )
8832            convert_element_type_3 = iota_1.to(torch.float32)
8833            sub_2 = (convert_element_type_1 + 0.5) * (sym_size / ceil) - 0.5
8834            clamp_min = sub_2.clamp_min(0.0)
8835            sub_3 = (convert_element_type_3 + 0.5) * (sym_size_1 / floor_1) - 0.5
8836            clamp_min_1 = sub_3.clamp_min(0.0)
8837            convert_element_type_4 = clamp_min.to(torch.int64)
8838            sub_4 = sym_size - 1
8839            clamp_max = clamp_min.ceil().clamp_max(sub_4)
8840            convert_element_type_5 = clamp_max.to(torch.int64)
8841            convert_element_type_6 = clamp_min_1.to(torch.int64)
8842            unsqueeze_2 = convert_element_type_4.unsqueeze(1)
8843            index = torch.ops.aten.index.Tensor(
8844                unsqueeze, [None, None, unsqueeze_2, convert_element_type_6]
8845            )
8846            index_1 = torch.ops.aten.index.Tensor(
8847                unsqueeze,
8848                [
8849                    None,
8850                    None,
8851                    convert_element_type_5.unsqueeze(1),
8852                    convert_element_type_6,
8853                ],
8854            )
8855            sub_6 = clamp_min.unsqueeze(1) - unsqueeze_2
8856            mul_10 = (index * (1.0 - sub_6) + index_1 * (sub_6)) * (
8857                1.0 - (clamp_min_1 - convert_element_type_6)
8858            )
8859            select = torch.ops.aten.select.int(mul_10, 0, 0)
8860            return (select,)
8861
8862        x = torch.randn(15, 20, 3)
8863        self.common(
8864            fn,
8865            [x],
8866        )
8867
8868    def test_setitem_with_int_parameter(self):
8869        x = torch.zeros(7, device=self.device)
8870
8871        def fn(n, a):
8872            a[n] = -1
8873            return a
8874
8875        cnts = CompileCounterWithBackend("inductor")
8876        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
8877
8878        for n in range(2, x.shape[0]):
8879            opt_fn(n, x)
8880            self.assertEqual(x[n], -1)
8881
8882        # If assume_static_by_default is set, the calls above will trigger
8883        # 3 function compilation:
8884        #   1. assuming 'n' is static (equals 2)
8885        #   2. making 'n' dynamic, but with the guard 'end <= x.shape[0]'
8886        #      (from: torch._inductor.ir.SliceView.create)
8887        frame_count = 2 if torch._dynamo.config.assume_static_by_default else 1
8888        self.assertEqual(cnts.frame_count, frame_count)
8889
8890        # Negative index triggers new compilation.
8891        opt_fn(-x.shape[0], x)
8892        self.assertEqual(x[0], -1)
8893        self.assertEqual(cnts.frame_count, frame_count + 1)
8894
8895    @config.patch(profiler_mark_wrapper_call=True)
8896    def test_profiler_mark_wrapper_call(self):
8897        from torch.profiler import profile
8898
8899        @torch._dynamo.optimize("inductor", nopython=True)
8900        def fn(a, b):
8901            return a + b
8902
8903        a = torch.rand((100,))
8904        b = torch.rand((100,))
8905        with profile() as prof:
8906            fn(a, b)
8907        assert any(
8908            "inductor_wrapper_call" in e.name for e in prof.profiler.function_events
8909        )
8910
8911    def test_insignificant_strides(self):
8912        def f(x):
8913            tmp = x + 1
8914            return tmp.view(-1, 1, 2)
8915
8916        x = torch.arange(8, device=self.device, dtype=torch.float32)
8917        out = f(x)
8918        compiled_out = torch.compile(f)(x)
8919
8920        self.assertEqual(out.stride(), compiled_out.stride())
8921        self.assertEqual(out, compiled_out)
8922
8923    @unittest.skipIf(IS_X86 and not HAS_AVX2, "Requires AVX2")
8924    def test_pixel_shuffle_channels_last(self):
8925        def fn(x):
8926            x = torch.nn.functional.pixel_shuffle(x, 2)
8927            x = torch.nn.functional.relu(x)
8928            return x
8929
8930        self.common(
8931            fn,
8932            (torch.randn(1, 16, 64, 72).to(memory_format=torch.channels_last),),
8933        )
8934
8935    def test_where_broadcast(self):
8936        # https://github.com/pytorch/pytorch/issues/93374
8937        def fn(x, p1, p0):
8938            o = torch.where(x, p1, p0)
8939            return o
8940
8941        # https://github.com/pytorch/pytorch/issues/94725
8942        class Repro(torch.nn.Module):
8943            def __init__(self):
8944                super().__init__()
8945                self.register_buffer(
8946                    "_tensor_constant0", torch.randn([], dtype=torch.float32)
8947                )
8948
8949            def forward(self, arg0_1, arg1_1):
8950                convert_element_type = torch.ops.prims.convert_element_type.default(
8951                    arg1_1, torch.bool
8952                )
8953                bitwise_not = torch.ops.aten.bitwise_not.default(convert_element_type)
8954                _tensor_constant0 = self._tensor_constant0
8955                lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(
8956                    _tensor_constant0
8957                )
8958                where = torch.ops.aten.where.self(bitwise_not, lift_fresh_copy, arg0_1)
8959                return (where, bitwise_not)
8960
8961        self.common(
8962            fn,
8963            (torch.tensor([[True]]), torch.rand(13, 7, 3), torch.rand(1, 1)),
8964        )
8965
8966        args = [
8967            torch.randn(1, 4, 64, 64),
8968            torch.zeros(1, 1, 64, 64, dtype=torch.uint8),
8969        ]
8970        args[1][:, :, :32, :32] = 1
8971        eager_args = [x.clone() for x in args]
8972        eager_mod = Repro()
8973        mod = make_fx(eager_mod, tracing_mode="real")(*args)
8974        compiled = compile_fx_inner(mod, args)
8975        inductor_out = compiled(args)
8976        eager_out = eager_mod(*eager_args)
8977        self.assertEqual(inductor_out, eager_out)
8978
8979    @skipIfRocm
8980    def test_require_stride_expanded(self):
8981        def forward(arg6, arg7, arg16):
8982            convolution = torch.ops.aten.convolution(
8983                arg16.unsqueeze(0), arg7, arg6, [4, 4], [2, 2], [1, 1], False, [0, 0], 1
8984            )
8985            return (convolution,)
8986
8987        self.common(
8988            forward,
8989            (
8990                None,
8991                rand_strided(
8992                    (64, 3, 11, 11),
8993                    (363, 121, 11, 1),
8994                    torch.float32,
8995                    device=self.device,
8996                ).to(memory_format=torch.channels_last),
8997                rand_strided(
8998                    (1, 3, 224, 224),
8999                    (150528, 50176, 224, 1),
9000                    torch.float32,
9001                    device=self.device,
9002                )
9003                .to(memory_format=torch.channels_last)
9004                .squeeze(0),
9005            ),
9006            atol=1e-3,
9007            rtol=0.001,
9008        )
9009
9010        # expanded dim should not cause copy in require_stride_order
9011        assertGeneratedKernelCountEqual(self, 0)
9012
9013    @requires_gpu()
9014    @unittest.skipIf(
9015        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
9016        "Does not support SDPA or pre-SM80 hardware",
9017    )
9018    @skipIfRocm
9019    def test_sdpa(self):
9020        def foo(arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
9021            view = torch.ops.aten.view.default(arg3_1, [23760, 128])
9022            arg3_1 = None
9023            mm = torch.ops.aten.mm.default(view, arg4_1)
9024            view = arg4_1 = None
9025            view_1 = torch.ops.aten.view.default(mm, [3, 99, 80, 8])
9026            mm = None
9027            view_2 = torch.ops.aten.view.default(view_1, [3, 99, 80, 8])
9028            view_1 = None
9029            permute = torch.ops.aten.permute.default(view_2, [0, 3, 1, 2])
9030            view_2 = None
9031            view_3 = torch.ops.aten.view.default(permute, [3, 8, 99, 80])
9032            permute = None
9033
9034            clone = torch.ops.aten.clone.default(
9035                view_3, memory_format=torch.contiguous_format
9036            )
9037            view_3 = None
9038
9039            expand = torch.ops.aten.expand.default(clone, [3, 8, 99, 80])
9040            clone = None
9041            _scaled_dot_product_efficient_attention = (
9042                torch.ops.aten._scaled_dot_product_efficient_attention.default(
9043                    arg0_1, arg1_1, arg2_1, expand, False
9044                )
9045            )
9046            arg0_1 = arg1_1 = arg2_1 = expand = None
9047            getitem = _scaled_dot_product_efficient_attention[0]
9048            _scaled_dot_product_efficient_attention = None
9049            return (getitem,)
9050
9051        DEVICE = torch.device(f"{GPU_TYPE}:0")
9052        DTYPE = torch.float16
9053        B = 3
9054        H = 8
9055        Q = 99
9056        K = 80
9057        D = 32
9058        C_bias = 128
9059
9060        # inputs
9061        query = torch.randn((B, H, Q, D), device=DEVICE, dtype=DTYPE)
9062        key = torch.randn((B, H, K, D), device=DEVICE, dtype=DTYPE)
9063        value = torch.randn((B, H, K, D), device=DEVICE, dtype=DTYPE)
9064        bias = torch.randn((B, Q, K, C_bias), device=DEVICE, dtype=DTYPE)
9065        weights = torch.randn((C_bias, H), device=DEVICE, dtype=DTYPE)
9066
9067        self.common(
9068            foo,
9069            (query, key, value, bias, weights),
9070            atol=0.02,
9071            rtol=1e4,
9072        )
9073
9074    @requires_gpu()
9075    @unittest.skipIf(
9076        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
9077        "Does not support mem_eff_attention",
9078    )
9079    @skipIfRocm
9080    def test_sdpa_unaligned_mask(self):
9081        def foo(
9082            arg0_1: "f32[8, 8, 16, 16]",
9083            arg1_1: "f32[8, 8, 15, 16]",
9084            arg2_1: "f32[8, 8, 15, 16]",
9085            arg3_1: "f32[1, 1, 16, 15]",
9086        ):
9087            constant_pad_nd: "f32[1, 1, 16, 16]" = (
9088                torch.ops.aten.constant_pad_nd.default(arg3_1, [0, 1], 0.0)
9089            )
9090            arg3_1 = None
9091            slice_1: "f32[1, 1, 16, 15]" = torch.ops.aten.slice.Tensor(
9092                constant_pad_nd, -1, 0, 15
9093            )
9094            constant_pad_nd = None
9095            expand: "f32[8, 8, 16, 15]" = torch.ops.aten.expand.default(
9096                slice_1, [8, 8, 16, 15]
9097            )
9098            slice_1 = None
9099            _scaled_dot_product_efficient_attention = (
9100                torch.ops.aten._scaled_dot_product_efficient_attention.default(
9101                    arg0_1, arg1_1, arg2_1, expand, False
9102                )
9103            )
9104            arg0_1 = arg1_1 = arg2_1 = expand = None
9105            getitem: "f32[8, 8, 16, 16]" = _scaled_dot_product_efficient_attention[0]
9106            _scaled_dot_product_efficient_attention = None
9107            return (getitem,)
9108
9109        query = torch.rand(8, 8, 16, 16, device=GPU_TYPE)
9110        key = torch.rand(8, 8, 15, 16, device=GPU_TYPE)
9111        value = torch.rand(8, 8, 15, 16, device=GPU_TYPE)
9112        bias = torch.rand(1, 1, 16, 15, device=GPU_TYPE)
9113        self.common(
9114            foo,
9115            (query, key, value, bias),
9116            atol=0.02,
9117            rtol=1e4,
9118        )
9119
9120    @requires_gpu()
9121    @unittest.skipIf(
9122        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
9123        "Does not support mem_eff_attention",
9124    )
9125    @skipIfRocm
9126    @config.patch(freezing=True)
9127    def test_sdpa_unaligned_mask_freezing(self):
9128        class Mod(torch.nn.Module):
9129            def __init__(self):
9130                super().__init__()
9131                self.arg3_1 = torch.rand(1, 1, 16, 15, device=GPU_TYPE)
9132
9133            def forward(
9134                self,
9135                arg0_1: "f32[8, 8, 16, 16]",
9136                arg1_1: "f32[8, 8, 15, 16]",
9137                arg2_1: "f32[8, 8, 15, 16]",
9138            ):
9139                arg3_1 = self.arg3_1
9140                constant_pad_nd: "f32[1, 1, 16, 16]" = (
9141                    torch.ops.aten.constant_pad_nd.default(arg3_1, [0, 1], 0.0)
9142                )
9143                arg3_1 = None
9144                slice_1: "f32[1, 1, 16, 15]" = torch.ops.aten.slice.Tensor(
9145                    constant_pad_nd, -1, 0, 15
9146                )
9147                constant_pad_nd = None
9148                expand: "f32[8, 8, 16, 15]" = torch.ops.aten.expand.default(
9149                    slice_1, [8, 8, 16, 15]
9150                )
9151                slice_1 = None
9152                _scaled_dot_product_efficient_attention = (
9153                    torch.ops.aten._scaled_dot_product_efficient_attention.default(
9154                        arg0_1, arg1_1, arg2_1, expand, False
9155                    )
9156                )
9157                arg0_1 = arg1_1 = arg2_1 = expand = None
9158                getitem: "f32[8, 8, 16, 16]" = _scaled_dot_product_efficient_attention[
9159                    0
9160                ]
9161                _scaled_dot_product_efficient_attention = None
9162                return (getitem,)
9163
9164        query = torch.rand(8, 8, 16, 16, device=GPU_TYPE)
9165        key = torch.rand(8, 8, 15, 16, device=GPU_TYPE)
9166        value = torch.rand(8, 8, 15, 16, device=GPU_TYPE)
9167
9168        mod = Mod()
9169        out_eager = mod(query, key, value)
9170
9171        with torch.no_grad():
9172            out_compiled = torch.compile(mod)(query, key, value)
9173            self.assertEqual(out_eager, out_compiled, atol=0.02, rtol=1e4)
9174
9175    def test_where_with_logical_op(self):
9176        def fn_and(x, y):
9177            return torch.where(torch.logical_and(x, y), 1.0, 0.0)
9178
9179        def fn_or(x, y):
9180            return torch.where(torch.logical_or(x, y), 1.0, 0.0)
9181
9182        self.common(
9183            fn_and,
9184            (torch.randn(32), torch.randn(32)),
9185        )
9186        self.common(
9187            fn_or,
9188            (torch.randn(32), torch.randn(32)),
9189        )
9190
9191    @skipIfRocm
9192    def test_conv_with_as_strided(self):
9193        class Model(nn.Module):
9194            def __init__(self):
9195                super().__init__()
9196                self.kv = torch.nn.Conv2d(
9197                    256, 384, kernel_size=(1, 1), stride=(1, 1), bias=False
9198                )
9199
9200            def forward(self, x):
9201                convolution = self.kv(x)
9202                constant_pad_nd = torch.ops.aten.constant_pad_nd.default(
9203                    convolution, [2, 2, 2, 2], 0.0
9204                )
9205                # as_strided inputs are depend on input's size and stide.
9206                as_strided = torch.ops.aten.as_strided.default(
9207                    constant_pad_nd, [8, 384, 2, 20, 12], [153600, 400, 160, 1, 20]
9208                )
9209                as_strided_1 = torch.ops.aten.as_strided.default(
9210                    as_strided, [8, 384, 2, 2, 12, 12], [153600, 400, 160, 8, 20, 1]
9211                )
9212                clone = torch.ops.aten.clone.default(
9213                    as_strided_1, memory_format=torch.contiguous_format
9214                )
9215                return clone
9216
9217        self.common(
9218            Model(),
9219            (torch.randn(8, 256, 16, 16),),
9220        )
9221
9222    def test_inplace_where_pointwise(self):
9223        # https://github.com/pytorch/pytorch/issues/96446
9224        def fn(a, b):
9225            a[0] = 2
9226            return a * b
9227
9228        self.common(fn, (torch.rand(1), torch.rand(2)))
9229
9230    def test_view_on_aliased(self):
9231        # https://github.com/pytorch/pytorch/issues/96728
9232        def fn1(a, b):
9233            a = a.max(0).values
9234            c = torch.cat((a, b))
9235            c = c.round()
9236            b >= a[0]  # noqa: B015
9237            return c
9238
9239        some_const = torch.tensor(6324)
9240
9241        def fn2():
9242            a = torch.tensor([[0.6324]])
9243            ret = torch.cat((a, a), dim=0)
9244            some_const >= a[0]  # noqa: B015
9245            return ret
9246
9247        self.common(fn1, (torch.tensor([[4.0]]), torch.tensor([5.0])))
9248        self.common(fn2, ())
9249
9250    def test_argmax_to_float(self):
9251        # https://github.com/pytorch/pytorch/issues/97127
9252        def fn():
9253            a = torch.zeros([2, 2])
9254            b = a.argmax(0)
9255            return b.float().mean()
9256
9257        self.common(fn, ())
9258
9259    def test_const_int32_to_float(self):
9260        # https://github.com/pytorch/pytorch/issues/97124
9261        def fn():
9262            a = torch.zeros([1, 2], dtype=torch.int32)
9263            a = a + a
9264            b = a.to(dtype=torch.float32)
9265            return b * 0.8
9266
9267        self.common(fn, ())
9268
9269    def test_getitem(self):
9270        out_features = ["p3", "p4", "p5", "p6", "p7"]
9271        in_feature = "p5"
9272
9273        def fn(a):
9274            return a[out_features.index(in_feature)]
9275
9276        x = [
9277            torch.rand([1, 256, 100, 152], device=self.device),
9278            torch.rand([1, 256, 50, 76], device=self.device),
9279            torch.rand([1, 256, 25, 38], device=self.device),
9280        ]
9281        opt_fn = torch._dynamo.optimize("inductor")(fn)
9282        same(fn(x), opt_fn(x))
9283
9284    def test_pad_view(self):
9285        def fn(a):
9286            y = torch.nn.functional.pad(a, (0, 0, 0, 1))
9287            y = y.view(*y.size()[:-2], y.size(-1), y.size(-2))
9288            return y
9289
9290        x = torch.rand(48, 3, 512, 512)
9291        self.common(fn, (x,))
9292
9293    def test_pad_cast(self):
9294        def fn(x):
9295            return torch.nn.functional.pad(x.to(torch.float32), (0, 3, 0, 0))
9296
9297        for dtype in [torch.int32, torch.int64]:
9298            self.common(fn, (torch.ones(1, 1, 13, dtype=dtype),))
9299
9300    @unittest.skipIf(not HAS_CPU, "requires C++ compiler")
9301    def test_data_type_propogation(self):
9302        from torch._dynamo.utils import detect_fake_mode
9303        from torch._inductor.codegen.common import boolean_ops
9304        from torch._inductor.compile_fx import _shape_env_from_inputs
9305        from torch._inductor.debug import DebugContext
9306        from torch._inductor.decomposition import decompositions
9307        from torch._inductor.graph import GraphLowering
9308        from torch._inductor.virtualized import V
9309        from torch.fx.passes.fake_tensor_prop import FakeTensorProp
9310
9311        def get_data_type(node: torch.fx.Node):
9312            if OptimizationContext.key in node.meta:
9313                return node.meta[OptimizationContext.key].dtype
9314            else:
9315                return None
9316
9317        def func(arg0_1):
9318            max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices.default(
9319                arg0_1, [3, 3], [2, 2], [1, 1]
9320            )
9321            arg0_1 = None
9322            getitem = max_pool2d_with_indices[0]
9323            max_pool2d_with_indices = None
9324            return (getitem,)
9325
9326        example_inputs = [
9327            torch.randn(10, 32, 20, 20, dtype=torch.bfloat16).to(
9328                memory_format=torch.channels_last
9329            )
9330        ]
9331
9332        gm = make_fx(func, decomposition_table=decompositions, tracing_mode="fake")(
9333            *example_inputs
9334        )
9335
9336        shape_env = _shape_env_from_inputs(example_inputs)
9337
9338        fake_mode = detect_fake_mode(example_inputs)
9339        if not fake_mode:
9340            fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
9341            FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs)
9342        else:
9343            FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(
9344                *example_inputs
9345            )
9346        with V.set_fake_mode(fake_mode):
9347            graph = GraphLowering(
9348                gm,
9349                shape_env=shape_env,
9350            )
9351            with V.set_graph_handler(graph), V.set_debug_handler(DebugContext()):
9352                graph.run(*example_inputs)
9353                graph.compile_to_module()
9354                scheduler_node = graph.scheduler.nodes[0]
9355                DataTypePropagation.propagate_scheduler_node(scheduler_node)
9356                root_graph = scheduler_node._body.root_block.graph
9357                for node in root_graph.nodes:
9358                    if node.op == "placeholder":
9359                        self.assertEqual(get_data_type(node), None)
9360                    elif node.target in boolean_ops():
9361                        self.assertEqual(get_data_type(node), torch.bool)
9362                    elif node.target in (
9363                        "constant",
9364                        "to_dtype",
9365                        "index_expr",
9366                    ):
9367                        self.assertEqual(get_data_type(node), node.args[-1])
9368                    elif node.target in (
9369                        "get_index",
9370                        "index_expr",
9371                    ):
9372                        self.assertEqual(get_data_type(node), torch.int64)
9373                    elif node.target in (
9374                        "load",
9375                        "store",
9376                    ):
9377                        self.assertEqual(
9378                            get_data_type(node), V.graph.get_dtype(node.args[1])
9379                        )
9380                    elif node.target == "reduction":
9381                        _, _, dtype, _, _, _, _ = node.args
9382                        self.assertEqual(get_data_type(node), dtype)
9383                    elif node.target.startswith("masked_subblock"):
9384                        """
9385                        masked_subblocks:
9386                        opcode       name       target     args                        kwargs
9387                        -----------  ---------  ---------  --------------------------  --------
9388                        placeholder  ops        ops        ()                          {}
9389                        call_module  get_index  get_index  ('index2',)                 {}
9390                        call_method  load       load       (ops, 'arg0_1', get_index)  {}
9391                        call_method  to_dtype   to_dtype   (ops, load, torch.float32)  {}
9392                        output       output     output     (to_dtype,)                 {}
9393                        """
9394                        self.assertEqual(get_data_type(node), torch.float)
9395                    elif node.target == "and_":
9396                        """
9397                        and_'s input is boolean_ops:
9398                        -----------  ---------  ---------  --------------------------  --------
9399                        call_method  and__22           and_              (ops, ge_15, lt_15)
9400                        -----------  ---------  ---------  --------------------------  --------
9401                        """
9402                        self.assertEqual(get_data_type(node), torch.bool)
9403                    elif node.target == "maximum":
9404                        """
9405                        maximum's input is maximum or masked_subblock:
9406                        -----------  ---------  ---------  --------------------------  --------
9407                        call_method  maximum_6         maximum           (ops, masked_subblock8, maximum_5)
9408                        -----------  ---------  ---------  --------------------------  --------
9409                        """
9410                        self.assertEqual(get_data_type(node), torch.float)
9411                    elif node.target == "output":
9412                        self.assertEqual(get_data_type(node), torch.bfloat16)
9413
9414    # Calling div only torch.SymInt arguments is not yet supported.
9415    # To support this behavior, we need to allow const-propping tensors that store symint data.
9416    # For now, dynamo will explicitly graph break when it encounters user code with this behavior.
9417    @expectedFailureCodegenDynamic
9418    def test_AllenaiLongformerBase_repro(self):
9419        def fn(query, scores, window_overlap):
9420            batch_size, seq_len, num_heads, _ = query.size()
9421            chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1
9422            diagonal_attention_scores = scores.new_zeros(
9423                (
9424                    batch_size * num_heads,
9425                    chunks_count + 1,
9426                    window_overlap,
9427                    window_overlap * 2 + 1,
9428                )
9429            )
9430            diagonal_attention_scores[:, :-1, :, window_overlap:] = scores[
9431                :, :, :window_overlap, : window_overlap + 1
9432            ]
9433            input_tensor = diagonal_attention_scores.view(
9434                batch_size, num_heads, seq_len, 2 * window_overlap + 1
9435            ).transpose(2, 1)
9436            beginning_input = input_tensor[:, :window_overlap, :, : window_overlap + 1]
9437            input_tensor[:, :window_overlap, :, : window_overlap + 1] = torch.full_like(
9438                beginning_input, -float("inf")
9439            )
9440            return input_tensor
9441
9442        args = [
9443            ((4, 1024, 12, 64), (768, 3072, 64, 1)),
9444            ((48, 3, 512, 513), (787968, 262656, 513, 1)),
9445        ]
9446        args = [rand_strided(sh, st) for (sh, st) in args]
9447        args.append(256)
9448
9449        if self.device == "cpu":
9450            opt_fn = torch._dynamo.optimize("inductor")(fn)
9451            _, code = run_and_get_cpp_code(opt_fn, *args)
9452            print(code)
9453            FileCheck().check_count(
9454                "static_cast<int32_t>(256)",
9455                1,
9456                exactly=True,
9457            ).run(code)
9458
9459        self.common(fn, args)
9460
9461    def test_cumsum_pattern_matcher_issue(self):
9462        def fn(input_ids) -> torch.Tensor:
9463            input_shape = input_ids.size()
9464            input_ids = input_ids.view(-1, input_shape[-1])
9465            batch_size, seq_length = input_shape
9466            past_key_values_length = 0
9467            mask_seq_length = past_key_values_length + seq_length
9468            attention_mask = torch.ones(
9469                batch_size, mask_seq_length, device=input_ids.device
9470            )
9471            attention_mask = attention_mask.long()
9472            return torch.cumsum(attention_mask, dim=1)
9473
9474        x = torch.randn(2, 2)
9475        self.common(fn, (x,), atol=0, rtol=0)
9476
9477    @staticmethod
9478    def _check_resize_common(
9479        self, fn, x, size_or_y, memory_format, inplace, deterministic
9480    ):
9481        x_ref_arg = x.clone()
9482        x_opt_arg = x.clone()
9483        x_numel = x.numel()
9484        torch._dynamo.reset_code_caches()
9485        opt_fn = torch._dynamo.optimize_assert(compile_fx)(fn)
9486        correct = fn(x_ref_arg, size_or_y, memory_format)
9487        actual = opt_fn(x_opt_arg, size_or_y, memory_format)
9488
9489        def get_numel(size_or_y):
9490            if isinstance(size_or_y, torch.Tensor):
9491                return size_or_y.numel()
9492            else:
9493                # assume shape
9494                return functools.reduce(lambda x, y: x * y, size_or_y, 1)
9495
9496        if deterministic:
9497            nele_check = correct.numel()
9498        else:
9499            nele_check = min(x_numel, get_numel(size_or_y))
9500
9501        correct_values = correct.as_strided((nele_check,), (1,))
9502        actual_values = actual.as_strided((nele_check,), (1,))
9503        self.assertTrue(same(correct_values, actual_values, equal_nan=deterministic))
9504        correct_strides = correct.stride()
9505        actual_strides = actual.stride()
9506        self.assertEqual(correct_strides, actual_strides)
9507
9508    @staticmethod
9509    def _cases_resize_common():
9510        sizes = [
9511            ((2,), (1, 3, 2, 3)),
9512            ((100,), (1, 3, 2, 3)),
9513            ((1, 3, 2, 3), (1, 3, 2, 3)),
9514            ((2,), (1, 3, 2, 3, 1)),
9515            ((100,), (1, 3, 2, 3, 1)),
9516            ((1, 3, 2, 3, 1), (1, 3, 2, 3, 1)),
9517            ((2, 0, 1), (2, 2)),
9518        ]
9519        for x_size, y_size in sizes:
9520            memory_formats = [torch.contiguous_format]
9521            if len(y_size) == 4:
9522                memory_formats.append(torch.channels_last)
9523            if len(y_size) == 5:
9524                memory_formats.append(torch.channels_last_3d)
9525            for memory_format in memory_formats:
9526                x = torch.randn(*x_size)
9527                yield x, y_size, memory_format
9528                # check some non-contiguous tensors
9529                if x.numel() == 100:
9530                    x_strided = x[::2].reshape(25, 2).transpose(0, 1)
9531                    yield x_strided, y_size, memory_format
9532
9533    def test_resize(self):
9534        def fn(x, size, memory_format):
9535            # NOTE: Tensor.resize() =/= aten::resize()
9536            return torch.ops.aten.resize(x, size, memory_format=memory_format)
9537
9538        for deterministic in [True, False]:
9539            with DeterministicGuard(
9540                deterministic, fill_uninitialized_memory=deterministic
9541            ):
9542                for x, y_size, memory_format in CommonTemplate._cases_resize_common():
9543                    CommonTemplate._check_resize_common(
9544                        self,
9545                        fn,
9546                        x,
9547                        y_size,
9548                        memory_format,
9549                        inplace=False,
9550                        deterministic=deterministic,
9551                    )
9552
9553    @staticmethod
9554    def _cases_resize_as_common():
9555        for x, y_size, memory_format in CommonTemplate._cases_resize_common():
9556            # each sizes /memory_format combintation tested in 2 ways:
9557            # 1. y is contiguous fn gets memory_format kwargs
9558            # 2. y has memory_format contiguity and fn gets preserve kwarg
9559            # 3. y has some other strides (not contiguous or channels last) and fn gets preserve
9560            yield x, torch.randn(*y_size), memory_format
9561            yield x, torch.randn(*y_size).contiguous(
9562                memory_format=memory_format
9563            ), torch.preserve_format
9564            yield x, torch.randn(*y_size).permute(
9565                tuple(reversed(range(len(y_size))))
9566            ), torch.preserve_format
9567
9568    def test_resize_as(self):
9569        def fn(x, y, memory_format):
9570            return torch.ops.aten.resize_as(x, y, memory_format=memory_format)
9571
9572        for deterministic in [True, False]:
9573            with DeterministicGuard(
9574                deterministic, fill_uninitialized_memory=deterministic
9575            ):
9576                for x, y, memory_format in CommonTemplate._cases_resize_as_common():
9577                    CommonTemplate._check_resize_common(
9578                        self,
9579                        fn,
9580                        x,
9581                        y,
9582                        memory_format,
9583                        inplace=False,
9584                        deterministic=deterministic,
9585                    )
9586
9587    def test_inplace_resize_as(self):
9588        def fn(x, y):
9589            x.resize_as_(y)
9590            return x
9591
9592        x = torch.randn(2, 3)
9593        y = torch.randn(200, 300)
9594        x_clone = x.clone()
9595        opt_fn = torch._dynamo.optimize("inductor")(fn)
9596        same(fn(x, y), opt_fn(x_clone, y))
9597
9598    def test_erfc(self):
9599        def fn(x):
9600            return torch.erfc(x)
9601
9602        self.common(fn, (torch.randn(8, 8),))
9603
9604    def test_erfinv(self):
9605        def fn(x):
9606            return torch.erfinv(x)
9607
9608        # domain for erfinv is (-1, 1)
9609        x = torch.empty(8, 8).uniform_(-1, 1)
9610        self.common(fn, (x,))
9611
9612    def test_uint(self):
9613        def fn(z):
9614            x = torch.tensor(5, device=z.device, dtype=torch.uint8)
9615            y = torch.neg(x)
9616            return x < y
9617
9618        self.common(fn, (torch.randn(26),))
9619
9620    def test_scaled_dot_product_attention(self):
9621        if self.device == "cuda" and not PLATFORM_SUPPORTS_FLASH_ATTENTION:
9622            raise unittest.SkipTest("Can't run flash attention on this platform")
9623        if self.device == "cuda" and TEST_WITH_ROCM:
9624            raise unittest.SkipTest(
9625                "Flash attention support is incomplete on this platform"
9626            )
9627
9628        def fn(q, k, v):
9629            return torch.nn.functional.scaled_dot_product_attention(
9630                q.transpose(1, 2).contiguous(),
9631                k.transpose(1, 2),
9632                v.transpose(1, 2),
9633                scale=0.125,
9634            )[:2]
9635
9636        self.common(
9637            fn,
9638            (
9639                torch.randn(4, 2, 4, 2),
9640                torch.randn(4, 2, 4, 2),
9641                torch.randn(4, 2, 4, 2),
9642            ),
9643            atol=2e-4,  # to pass lowp check on GPU
9644            rtol=1e-2,  # to pass lowp check on GPU
9645        )
9646
9647    @skipIfRocm
9648    @expectedFailureXPU
9649    def test_scaled_dot_product_efficient_attention(self):
9650        if self.device == "cpu":
9651            raise unittest.SkipTest(f"requires {GPU_TYPE}")
9652
9653        # The first two values should be the same, attention output
9654        # and logsumexp since dropout is not being set
9655        def fn(q, k, v, attn_bias, compute_log_sumexp):
9656            return aten._scaled_dot_product_efficient_attention(
9657                q, k, v, attn_bias, compute_log_sumexp
9658            )[:2]
9659
9660        self.common(
9661            fn,
9662            (
9663                torch.randn(4, 4, 36, 36),
9664                torch.randn(4, 4, 36, 36),
9665                torch.randn(4, 4, 36, 36),
9666                torch.randn(4, 4, 36, 36),
9667                False,
9668            ),
9669            check_lowp=False,
9670        )
9671
9672    def test_fft_real_input(self):
9673        def fn(x):
9674            return torch.fft.fftn(x)
9675
9676        self.common(fn, (torch.randn((16, 16, 16)),), check_lowp=False)
9677
9678    def test_fft_real_input_real_output(self):
9679        def fn(x):
9680            return torch.fft.fftn(x).real
9681
9682        self.common(fn, (torch.randn((16, 16, 16)),), check_lowp=False)
9683
9684    def test_bucketize(self):
9685        def fn(input, boundaries, out_int32, right):
9686            return torch.bucketize(input, boundaries, out_int32=out_int32, right=right)
9687
9688        input = torch.rand((64, 64)) * 2 - 1
9689        boundaries = torch.tensor([-0.9, -0.8, 0.1, 0.2, 0.5, 0.9])
9690
9691        for out_int32 in [True, False]:
9692            for right in [True, False]:
9693                out_int32 = True
9694                right = False
9695                self.common(fn, (input, boundaries, out_int32, right), check_lowp=False)
9696
9697    def test_bucketize_default_kwargs(self):
9698        def fn(input, offsets):
9699            return torch.bucketize(input, offsets)
9700
9701        input = torch.tensor(
9702            [-1.0, -0.9, -0.8, -0.5, 0.0, 0.1, 0.2, 0.4, 0.5, 0.6, 0.9, 0.91]
9703        )
9704        offsets = torch.tensor([-0.9, -0.8, 0.1, 0.2, 0.5, 0.9])
9705
9706        self.common(fn, (input, offsets), check_lowp=False)
9707
9708    def test_bucketize_int(self):
9709        def fn(input, offsets, out_int32, right):
9710            return torch.bucketize(input, offsets, out_int32=out_int32, right=right)
9711
9712        input = torch.randint(0, 102, (64, 64))
9713        offsets = torch.arange(10, dtype=torch.int32) ** 2 + 1
9714
9715        for out_int32 in [True, False]:
9716            for right in [True, False]:
9717                self.common(fn, (input, offsets, out_int32, right), check_lowp=False)
9718
9719    @patch.object(config.triton, "autotune_pointwise", True)
9720    def test_bucketize_add_autotune(self):
9721        # Causes a @pointwise(size_hints) where size_hints is 2D
9722
9723        def fn(input, offsets, add_value):
9724            return torch.bucketize(input, offsets) + add_value
9725
9726        input = torch.rand((16, 16, 64, 64))
9727        boundaries = torch.tensor([-0.9, -0.8, 0.1, 0.2, 0.5, 0.9])
9728        add_value = torch.randint(0, 1024, (16, 16, 64, 64)).to(
9729            memory_format=torch.channels_last
9730        )
9731
9732        self.common(fn, (input, boundaries, add_value), check_lowp=False)
9733
9734        assertGeneratedKernelCountEqual(self, 1)
9735
9736    def test_bucketize_computed_offsets(self):
9737        def fn(inp, offsets):
9738            return torch.bucketize(inp, offsets + 0.01)
9739
9740        inp = torch.tensor(
9741            [-1.0, -0.9, -0.8, -0.5, 0.0, 0.1, 0.2, 0.4, 0.5, 0.6, 0.9, 0.91]
9742        )
9743        offsets = torch.tensor([-0.9, -0.8, 0.1, 0.2, 0.5, 0.9]) - 0.01
9744
9745        self.common(fn, (inp, offsets), check_lowp=False)
9746
9747    @requires_gpu()
9748    @config.patch(assume_aligned_inputs=False)
9749    def test_config_option_dont_assume_alignment(self):
9750        def fn(x: torch.Tensor) -> torch.Tensor:
9751            return x.sin() + x.cos()
9752
9753        # Inductor specializes on the (unguarded) alignment of the initial input.
9754        # Make sure that for different configurations, nothing breaks.
9755        for offset in (0, 1, 2, 3, 4):
9756            base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE)
9757            inp = torch.as_strided(base, (64, 64), (64, 1), offset)
9758            torch._dynamo.reset()
9759            fn_c = torch.compile(fn)
9760
9761            ref = fn(inp)
9762            res = fn_c(inp)
9763            self.assertEqual(ref, res)
9764
9765            for offset2 in (0, 1, 2, 3, 4):
9766                base2 = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE)
9767                inp2 = torch.as_strided(base, (64, 64), (64, 1), offset2)
9768                ref2 = fn(inp2)
9769                res2 = fn_c(inp2)
9770                self.assertEqual(ref2, res2)
9771
9772    @requires_gpu()
9773    @config.patch(assume_aligned_inputs=False)
9774    def test_config_option_dont_assume_alignment_recompiles(self):
9775        # Inputs:
9776        #  1. (32, 32) shape
9777        #  2. (64, 64) shape -> causes a recompile
9778        #  3. (64, 64) shape with different storage offset -> should NOT cause a recompile
9779        failed_guards = []
9780
9781        def fail(guard):
9782            nonlocal failed_guards
9783            failed_guards.append(guard)
9784
9785        def fn(x: torch.Tensor) -> torch.Tensor:
9786            return x.sin() + x.cos()
9787
9788        base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE)
9789
9790        inp1 = torch.as_strided(base, (32, 32), (32, 1), 4)
9791        inp2 = torch.as_strided(base, (64, 64), (64, 1), 4)
9792        inp3 = torch.as_strided(base, (64, 64), (64, 1), 5)
9793
9794        torch._dynamo.reset()
9795
9796        fn_c = torch._dynamo.optimize("inductor", guard_fail_fn=fail)(fn)
9797
9798        ref1 = fn(inp1)
9799        res1 = fn_c(inp1)
9800        self.assertEqual(ref1, res1)
9801        self.assertEqual(0, len(failed_guards))
9802
9803        ref2 = fn(inp2)
9804        res2 = fn_c(inp2)
9805        self.assertEqual(ref2, res2)
9806        # if dynamic shapes isn't already turned on, we might have a guard failure as we turn
9807        # on dynamic shapes
9808        self.assertLessEqual(len(failed_guards), 1)
9809        failed_guard_count_iteration_2 = len(failed_guards)
9810
9811        failed_guards = []
9812        ref3 = fn(inp3)
9813        res3 = fn_c(inp3)
9814        self.assertEqual(ref3, res3)
9815        # we might still have the dynamics shapes failure, but offset change shouldn't be guarded on
9816        # see Note: [Input Alignment handling in Inductor]
9817        self.assertLessEqual(len(failed_guards), failed_guard_count_iteration_2)
9818
9819    @requires_gpu()
9820    @config.patch(assume_aligned_inputs=False)
9821    def test_config_option_dont_assume_alignment_cudagraphs(self):
9822        def fn(x):
9823            return x.cos() * x.sin()
9824
9825        fn_c = torch.compile(fn, mode="reduce-overhead", dynamic=True)
9826
9827        for size, stride, offset in (
9828            ((32, 32), (32, 1), 4),
9829            ((48, 48), (48, 1), 4),
9830            ((64, 64), (64, 1), 5),
9831        ):
9832            torch.manual_seed(42)
9833            base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE)
9834            torch.manual_seed(42)
9835            base_ref = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE)
9836
9837            inp = torch.as_strided(base, size, stride, offset)
9838            inp_ref = torch.as_strided(base_ref, size, stride, offset)
9839
9840            inp.requires_grad_(True)
9841            inp_ref.requires_grad_(True)
9842
9843            res = fn_c(inp)
9844            ref = fn(inp_ref)
9845            self.assertEqual(ref, res)
9846
9847            res.sum().backward()
9848            ref.sum().backward()
9849            self.assertEqual(base.grad, base_ref.grad)
9850
9851    @config.patch(implicit_fallbacks=True)
9852    def test_custom_op_1(self):
9853        import torch.library
9854
9855        def foo_cpu(x):
9856            return 3 * x
9857
9858        def foo_cuda(x):
9859            return 3 * x
9860
9861        def foo_xpu(x):
9862            return 3 * x
9863
9864        def foo_meta(x):
9865            return torch.empty_like(x)
9866
9867        define_custom_op_for_test("foo", foo_cpu, foo_cuda, foo_xpu, foo_meta)
9868
9869        def fn(x):
9870            a = torch.nn.functional.relu(x)
9871            b = torch.ops.test.foo(a)
9872            c = torch.cos(b)
9873            return c
9874
9875        self.common(fn, (torch.randn((16, 32)),), check_lowp=False)
9876
9877    @config.patch(implicit_fallbacks=True)
9878    def test_custom_op_2(self):
9879        import torch.library
9880
9881        def foo_cpu(x, scale: float):
9882            return scale * x, torch.cos(x)
9883
9884        def foo_cuda(x, scale: float):
9885            return scale * x, torch.cos(x)
9886
9887        def foo_xpu(x, scale: float):
9888            return scale * x, torch.cos(x)
9889
9890        def foo_meta(x, scale: float):
9891            return torch.empty_like(x), torch.empty_like(x)
9892
9893        define_custom_op_2_for_test("foo2", foo_cpu, foo_cuda, foo_xpu, foo_meta)
9894
9895        def fn(x, scale: float):
9896            a = torch.nn.functional.relu(x)
9897            return torch.ops.test.foo2(a, scale)
9898
9899        self.common(fn, (torch.randn((16, 32)), 2.0), check_lowp=False)
9900
9901    @config.patch(implicit_fallbacks=True)
9902    def test_custom_op_3(self):
9903        import torch.library
9904
9905        def foo_cpu(x):
9906            result = torch.zeros_like(x[0])
9907            for t in x:
9908                result += t
9909            return result
9910
9911        def foo_cuda(x):
9912            result = torch.zeros_like(x[0])
9913            for t in x:
9914                result += t
9915            return result
9916
9917        def foo_xpu(x):
9918            result = torch.zeros_like(x[0])
9919            for t in x:
9920                result += t
9921            return result
9922
9923        def foo_meta(x):
9924            return torch.empty_like(x[0])
9925
9926        define_custom_op_3_for_test("foo3", foo_cpu, foo_cuda, foo_xpu, foo_meta)
9927
9928        def fn(x):
9929            return torch.ops.test.foo3(x)
9930
9931        self.common(
9932            fn,
9933            ([torch.randn((16, 32)), torch.randn((16, 32)), torch.randn((16, 32))],),
9934            check_lowp=False,
9935        )
9936
9937    @requires_gpu()
9938    @torch._inductor.config.patch("layout_optimization", True)
9939    @torch._inductor.config.patch("keep_output_stride", False)
9940    @config.patch(implicit_fallbacks=True)
9941    def test_custom_op_fixed_layout_sequential(self):
9942        import torch.library
9943
9944        mod = nn.Conv2d(3, 128, 1, stride=1, bias=False).to(device=GPU_TYPE)
9945        inp = torch.rand(2, 3, 128, 128, device=GPU_TYPE)
9946        expected_stride = mod(inp).stride()
9947
9948        def bar_cpu(x):
9949            self.assertEqual(x.stride(), expected_stride)
9950            return x.clone()
9951
9952        def bar_cuda(x):
9953            self.assertEqual(x.stride(), expected_stride)
9954            return x.clone()
9955
9956        def bar_xpu(x):
9957            self.assertEqual(x.stride(), expected_stride)
9958            return x.clone()
9959
9960        def bar_meta(x):
9961            return torch.empty_like(x)
9962
9963        define_custom_op_for_test(
9964            "bar",
9965            bar_cpu,
9966            bar_cuda,
9967            bar_xpu,
9968            bar_meta,
9969            tags=[torch._C.Tag.needs_fixed_stride_order],
9970        )
9971
9972        def fn(x):
9973            z = mod(x)
9974            output = torch.ops.test.bar(z)
9975            return output
9976
9977        with torch.no_grad():
9978            # With keep_output_stride False, inductor would normally have different layout from eager execution
9979            # But because our custom op needs fixed layout, the assertions in the custom op will pass
9980            self.common(fn, (inp,), check_lowp=False)
9981
9982    @config.patch(implicit_fallbacks=True)
9983    def test_mutable_custom_op_fixed_layout(self):
9984        with torch.library._scoped_library("mylib", "DEF") as lib:
9985            lib.define(
9986                "copy_(Tensor(a!) dst, Tensor src) -> ()",
9987                tags=torch.Tag.needs_fixed_stride_order,
9988            )
9989
9990            @torch.library.impl(lib, "copy_", "Meta")
9991            def _(dst, src):
9992                return None
9993
9994            @torch.library.impl(lib, "copy_", "CompositeExplicitAutograd")
9995            def _(dst, src):
9996                dst.copy_(src)
9997
9998            def f(x):
9999                full_default_3 = torch.full([3], 7.0, device="cpu")
10000                chunk_cat_default_1 = torch.ops.mylib.copy_.default(full_default_3, x)
10001                mul_out = torch.mul(full_default_3, full_default_3)
10002                return mul_out
10003
10004            x = torch.arange(3, dtype=torch.float, device="cpu")
10005            eager_out = f(x)
10006
10007            compiled_inductor_f = torch.compile(f, backend="inductor", fullgraph=True)
10008            compiled_inductor_out = compiled_inductor_f(x)
10009            self.assertEqual(compiled_inductor_out, eager_out)
10010
10011    @requires_gpu()
10012    @config.patch(implicit_fallbacks=True)
10013    def test_custom_op_fixed_layout_channels_last(self):
10014        class Block(nn.Module):
10015            def __init__(
10016                self,
10017            ):
10018                super().__init__()
10019
10020                self.in_layers = nn.Sequential(
10021                    nn.Dropout(p=0.1),
10022                )
10023
10024            def helper(self, x):
10025                out = F.gelu(x)
10026                out = self.in_layers(out)
10027                return out
10028
10029            def forward(self, x):
10030                out = self.helper(x)
10031                out = torch.ops.test.baz(out)
10032                return out
10033
10034        model = Block()
10035        model = model.to(GPU_TYPE).to(memory_format=torch.channels_last)
10036        input_t = torch.randn([1, 320, 128, 128], dtype=torch.float32, device=GPU_TYPE)
10037        input_t = input_t.to(memory_format=torch.channels_last)
10038        expected_strides = model.helper(input_t).stride()
10039
10040        def baz_cpu(x):
10041            self.assertEqual(expected_strides, x.stride())
10042            return x.clone()
10043
10044        def baz_cuda(x):
10045            self.assertEqual(expected_strides, x.stride())
10046            return x.clone()
10047
10048        def baz_xpu(x):
10049            self.assertEqual(expected_strides, x.stride())
10050            return x.clone()
10051
10052        def baz_meta(x):
10053            return torch.empty_like(x)
10054
10055        define_custom_op_for_test(
10056            "baz",
10057            baz_cpu,
10058            baz_cuda,
10059            baz_xpu,
10060            baz_meta,
10061            tags=[torch._C.Tag.needs_fixed_stride_order],
10062        )
10063
10064        with torch.no_grad():
10065            net = torch.compile(model)
10066            out = net(input_t)
10067
10068    def test_buffer_use_after_remove(self):
10069        # https://github.com/pytorch/pytorch/issues/102857
10070
10071        def rotvec_to_rotmat(rotvec) -> torch.Tensor:
10072            """Simplified rotvec to rotmat code from RoMa
10073            (https://github.com/naver/roma/blob/06e4b0cdc1c802a60a012bb19c581d6600c63358/roma/mappings.py#L371)
10074            """
10075            theta = torch.norm(rotvec, dim=-1)
10076            axis = rotvec / theta[..., None]
10077            kx, ky, kz = axis[:, 0], axis[:, 1], axis[:, 2]
10078            sin_theta = torch.sin(theta)
10079            cos_theta = torch.cos(theta)
10080            one_minus_cos_theta = 1 - cos_theta
10081            xs = kx * sin_theta
10082            ys = ky * sin_theta
10083            zs = kz * sin_theta
10084            xyc = kx * ky * one_minus_cos_theta
10085            xzc = kx * kz * one_minus_cos_theta
10086            yzc = ky * kz * one_minus_cos_theta
10087            xxc = kx**2 * one_minus_cos_theta
10088            yyc = ky**2 * one_minus_cos_theta
10089            zzc = kz**2 * one_minus_cos_theta
10090            R_rodrigues = torch.stack(
10091                [
10092                    1 - yyc - zzc,
10093                    xyc - zs,
10094                    xzc + ys,
10095                    xyc + zs,
10096                    1 - xxc - zzc,
10097                    -xs + yzc,
10098                    xzc - ys,
10099                    xs + yzc,
10100                    1 - xxc - yyc,
10101                ],
10102                dim=-1,
10103            ).reshape(-1, 3, 3)
10104            R = R_rodrigues
10105            return R
10106
10107        def f(coord, rot, trans):
10108            rot_mat = rotvec_to_rotmat(rot)
10109            coord = torch.einsum("...ij,...bj->...bi", rot_mat, coord) + trans
10110            return coord.sum()
10111
10112        foo_c = torch.compile(f, dynamic=True)
10113
10114        def run(fn):
10115            coord = torch.ones((2, 3), device=self.device)
10116            rot = nn.Parameter(torch.ones((2, 3), device=self.device))
10117            trans = nn.Parameter(torch.ones((2, 3), device=self.device))
10118
10119            U = fn(coord, rot, trans)
10120            U.backward()
10121
10122            return U, rot, trans
10123
10124        U_e, rot_e, trans_e = run(f)
10125        U, rot, trans = run(foo_c)
10126
10127        self.assertEqual(U, U_e)
10128        self.assertEqual(rot.grad, rot_e.grad)
10129        self.assertEqual(trans.grad, trans_e.grad)
10130
10131    @config.patch({"fx_graph_cache": False})
10132    def test_inner_fn_str_and_stride(self):
10133        def f(x):
10134            x = x + 1
10135            x = test_operators.realize(x)
10136            x = x * 2
10137            x = test_operators.realize(x)
10138            return x
10139
10140        x = torch.rand(3, 2, device=self.device).t()
10141        ref = f(x)
10142        called = False
10143
10144        def hook_fn(scheduler, nodes):
10145            nonlocal called
10146            called = True
10147
10148            if self.device != "cpu":
10149                self.assertEqual(len(nodes), 3)
10150                _, mul_buf, _ = nodes
10151                self.assertTrue(
10152                    all(
10153                        V.graph.sizevars.size_hints(buf.get_stride()) == (1, 2)
10154                        for buf in nodes
10155                    )
10156                )
10157                # before the fix, the wrong index expression
10158                # 'i1 + 3 * i0' is cached.
10159                self.assertTrue(
10160                    "i0 + 2 * i1" in mul_buf.data.inner_fn_str()
10161                    or "i0 + i1 * s1" in mul_buf.data.inner_fn_str()
10162                )
10163
10164        with add_scheduler_init_hook(hook_fn):
10165            actual = torch.compile(f, fullgraph=True)(x)
10166        self.assertEqual(ref, actual)
10167        self.assertTrue(called)
10168
10169    def test_mutations_loop_fusion(self):
10170        def fn(tensor, index, source):
10171            out = tensor.index_add(0, index, source, alpha=2.0) / 2
10172            return out
10173
10174        device = "cpu"
10175        tensor = torch.rand((1,), dtype=torch.double, device=device)
10176        index = torch.tensor([0], dtype=torch.long, device=device)
10177        source = torch.rand((1,), dtype=torch.double, device=device)
10178        self.common(
10179            fn,
10180            (
10181                tensor,
10182                index,
10183                source,
10184            ),
10185        )
10186
10187    @config.patch(
10188        "triton.autotune_pointwise", True
10189    )  # needed to introduce config that exceed max shared memory usage
10190    @serialTest()
10191    def test_large_block_sizes(self):
10192        """
10193        Inductor will try triton configs like x = 64 and y = 1024 which will
10194        result in out of shared memory if dtype is fp32.
10195
10196        Currently inductor will skip such bad configs and pick the best one
10197        from the remaining configs.
10198        """
10199        if not _has_sufficient_memory(self.device, 3 * 2**24 * 65 * 4):
10200            raise unittest.SkipTest("insufficient memory")
10201
10202        @torch.compile
10203        def fn(x, y):
10204            return x.t() + y
10205
10206        # Use shape (2**24, 65) rather than (2**24, 128) potentially avoid OOM in
10207        # CI while still keep the same up-rounded size-hints.
10208        a = torch.randn(2**24, 65, device=self.device)
10209        b = torch.randn(65, 2**24, device=self.device)
10210        fn(a, b)
10211
10212    # Skipped on ROCm until https://github.com/ROCm/triton/issues/443 resolved
10213    @skipIfRocm
10214    def test_fuse_large_params(self):
10215        def pt2_optimizer_step(optimizer):
10216            @torch.compile()
10217            def f():
10218                optimizer.step()
10219
10220            f()
10221
10222        params = [
10223            torch.rand(10, 10, dtype=torch.float32, device=self.device)
10224            for _ in range(194)
10225        ]
10226        for p in params:
10227            p.grad = torch.rand_like(p)
10228
10229        o = torch.optim.AdamW(params)
10230        pt2_optimizer_step(o)
10231
10232    def test_adaptive_avg_pool1d_argmax(self):
10233        # https://github.com/pytorch/pytorch/issues/113013
10234        def fn(x):
10235            x = torch.adaptive_avg_pool1d(input=x, output_size=2)
10236            x = torch.argmax(input=x)
10237            return x
10238
10239        x = torch.rand([4, 4, 3], dtype=torch.float64)
10240        self.common(fn, (x,))
10241
10242    def test_float16_to_int16(self):
10243        def fn(x):
10244            x_view = x.view(dtype=torch.int16)
10245            return x_view.mul(2)
10246
10247        x = torch.ones(4, dtype=torch.float16, device=self.device)
10248        ref = fn(x)
10249        actual = torch.compile(fn)(x)
10250        self.assertEqual(ref, actual)
10251
10252    @skipCUDAIf(not SM80OrLater, "uses bfloat16 which requires SM >= 80")
10253    def test_bfloat16_to_int16(self):
10254        def fn(a, b):
10255            x = a + b
10256            x_view = x.view(dtype=torch.int16)
10257            return x_view.mul(2)
10258
10259        a = torch.ones(4, dtype=torch.bfloat16, device=self.device)
10260        b = torch.ones(4, dtype=torch.bfloat16, device=self.device)
10261        ref = fn(a, b)
10262        actual = torch.compile(fn)(a, b)
10263        self.assertEqual(ref, actual)
10264
10265    def test_float32_to_int32(self):
10266        def fn(a, b):
10267            x = a + b
10268            x_view = x.view(dtype=torch.int32)
10269            return x_view.mul(2)
10270
10271        a = torch.ones(4, dtype=torch.float32, device=self.device)
10272        b = torch.ones(4, dtype=torch.float32, device=self.device)
10273        ref = fn(a, b)
10274        actual = torch.compile(fn)(a, b)
10275        self.assertEqual(ref, actual)
10276
10277    def test_randint_int64_mod(self):
10278        # This used to not compile due to a wrong return type of randint64_cpu
10279        # See https://github.com/pytorch/pytorch/issues/117435
10280        def fn(n):
10281            return (
10282                torch.randint(
10283                    low=-5, high=5, size=(n,), dtype=torch.int64, device=self.device
10284                )
10285                % 10
10286            )
10287
10288        res = torch.compile(fn)(20)
10289        self.assertTrue(torch.all((0 <= res) & (res < 10)).item())
10290
10291    @torch._inductor.config.patch(force_shape_pad=True)
10292    def test_should_pad_bench_for_bmm(self):
10293        B = 2
10294        M = 1024
10295        N = 1024
10296        K = 1024 + 1  # a size that requires padding
10297
10298        mat1 = torch.rand(B, M, K, device=self.device)
10299        mat2 = torch.rand(B, K, N, device=self.device)
10300
10301        should_pad = pad_mm.should_pad_bench(None, mat1, mat2, torch.ops.aten.bmm)
10302
10303        self.assertTrue(should_pad)
10304
10305    @parametrize(
10306        "name, op",
10307        [
10308            subtest((name, getattr(torch.special, name)), name=name)
10309            for name in torch.special.__all__
10310            if name not in {"softmax", "log_softmax", "logsumexp"}
10311        ],
10312    )
10313    def test_pointwise(self, name, op):
10314        dtype = torch.float32
10315        check_lowp = True
10316        if self.device == GPU_TYPE and name in {
10317            "airy_ai",
10318            "bessel_i0",
10319            "bessel_i1",
10320            "bessel_j0",
10321            "bessel_j1",
10322            "bessel_y0",
10323            "bessel_y1",
10324            "erfcx",
10325            "gammainc",
10326            "gammaincc",
10327            "i1",
10328            "i1e",
10329            "modified_bessel_i0",
10330            "modified_bessel_i1",
10331            "modified_bessel_k0",
10332            "modified_bessel_k1",
10333            "ndtri",
10334            "scaled_modified_bessel_k0",
10335            "scaled_modified_bessel_k1",
10336            "spherical_bessel_j0",
10337            "zeta",
10338            "chebyshev_polynomial_t",
10339            "chebyshev_polynomial_v",
10340            "chebyshev_polynomial_u",
10341            "chebyshev_polynomial_w",
10342            "legendre_polynomial_p",
10343            "shifted_chebyshev_polynomial_t",
10344            "shifted_chebyshev_polynomial_u",
10345            "shifted_chebyshev_polynomial_v",
10346            "shifted_chebyshev_polynomial_w",
10347            "hermite_polynomial_h",
10348            "hermite_polynomial_he",
10349            "laguerre_polynomial_l",
10350        }:
10351            # <func>_cuda not implemented for Half
10352            check_lowp = False
10353
10354        if name in {"gammainc", "gammaincc"}:
10355            args = (
10356                torch.randn(8, 8, dtype=dtype, device=self.device),
10357                torch.empty(8, 8, dtype=dtype, device=self.device).uniform_(1, 2),
10358            )
10359
10360            def fn(x, y):
10361                return op(x, y)
10362
10363        elif name in {"xlog1py", "xlogy", "zeta"}:
10364            args = (
10365                torch.randn(8, 8, dtype=dtype, device=self.device),
10366                torch.empty(8, 8, dtype=dtype, device=self.device).uniform_(1, 2),
10367            )
10368
10369            def fn(x, y):
10370                return op(x, y)
10371
10372        elif name == "multigammaln":
10373            args = (
10374                torch.empty(8, 8, dtype=dtype, device=self.device).uniform_(1, 2),
10375                2,
10376            )
10377
10378            def fn(x, p):
10379                return op(x, p)
10380
10381        elif name == "polygamma":
10382            args = (
10383                1,
10384                torch.empty(8, 8, dtype=dtype, device=self.device).uniform_(1, 10),
10385            )
10386
10387            def fn(n, x):
10388                return op(n, x)
10389
10390        elif "_polynomial_" in name:
10391            args = (
10392                torch.randn(8, 8, dtype=dtype, device=self.device),
10393                2,
10394            )
10395
10396            def fn(x, n):
10397                return op(x, n)
10398
10399        else:
10400            args = (torch.randn(8, 8, dtype=dtype, device=self.device),)
10401
10402            def fn(x):
10403                return op(x)
10404
10405        self.common(fn, args, check_lowp=check_lowp)
10406
10407    # codegen test fails with no dynamic for loop in dynamic shape tests
10408    @expectedFailureCodegenDynamic
10409    def test_view_uint8_through_differing_bitwidths(self):
10410        # https://github.com/pytorch/pytorch/issues/120998
10411        def fn(x, view_dtype):
10412            return x.view(view_dtype).view(torch.uint8)
10413
10414        view_dtypes = [torch.int16, torch.int32, torch.int64]
10415        for dtype in view_dtypes:
10416            x = torch.randint(0, 2**4, [4096, 4096], dtype=torch.uint8)
10417            self.common(
10418                fn,
10419                (
10420                    x,
10421                    dtype,
10422                ),
10423            )
10424
10425    @torch._dynamo.config.patch(capture_scalar_outputs=True)
10426    def test_split_with_sizes_with_unbacked_symints(self):
10427        @torch.compile()
10428        def f(sz, x):
10429            s0, s1 = sz.tolist()
10430            r0, r1 = torch.ops.aten.split_with_sizes.default(x, [s0, s1])
10431            return torch.ops.aten.sort.default(r1)
10432
10433        N = 7312
10434        S0 = 420
10435        S1 = N - S0
10436
10437        result = f(torch.tensor([S0, S1]), torch.randn(N))
10438        self.assertTrue(len(result) == 2)
10439
10440        @torch.compile()
10441        def f2(x):
10442            y = torch.arange(x.item())
10443            return torch.ops.aten.split_with_sizes.default(y, [5, 5, 10])
10444
10445        result = f2(torch.tensor([20]))
10446        self.assertTrue(len(result) == 3)
10447
10448    @torch._dynamo.config.patch(capture_scalar_outputs=True)
10449    def test_split_with_unbacked_symints(self):
10450        # https://github.com/pytorch/pytorch/issues/122937
10451        @torch.compile()
10452        def f(x):
10453            y = torch.arange(x.item())
10454            return torch.split(y, [5, 5, 10])
10455
10456        result = f(torch.tensor([20]))
10457        self.assertTrue(len(result) == 3)
10458
10459    def test_complex_memory_overlap(self):
10460        t = rand_strided((8, 1500, 1), (1504, 1, 1), device=self.device)
10461        self.assertFalse(complex_memory_overlap(t))
10462
10463    def test_generate_rand_fp8(self):
10464        """
10465        PyTorch can not generate fp8 tensors with a normal distribution because of
10466        missing needed kernels.
10467
10468        We work around that in rand_strided by generating an fp16 tensor first and
10469        then do casting.
10470        """
10471        t = rand_strided((2, 3), (3, 1), device=self.device, dtype=torch.float8_e4m3fn)
10472        self.assertTrue(t.dtype is torch.float8_e4m3fn)
10473
10474    def test_large_grid(self):
10475        # https://github.com/pytorch/pytorch/issues/123210
10476        def fn(primals_5):
10477            view = torch.ops.aten.reshape.default(primals_5, [-1, 2, 4])
10478            primals_5 = None
10479            permute = torch.ops.aten.permute.default(view, [0, 2, 1])
10480            clone = torch.ops.aten.clone.default(
10481                permute, memory_format=torch.contiguous_format
10482            )
10483            return clone
10484
10485        s0 = 16777472
10486        s1 = 8
10487        compiled_fn = torch._dynamo.optimize()(fn)
10488        actual = compiled_fn(torch.ones(s0, s1))
10489        self.assertTrue((actual == 1).all())
10490
10491
10492@dataclasses.dataclass
10493class TestFailure:
10494    suffixes: Tuple[str]
10495    is_skip: bool = False
10496    __test__: bool = False
10497
10498
10499def copy_tests(
10500    my_cls, other_cls, suffix, test_failures=None, xfail_prop=None
10501):  # noqa: B902
10502    for name, value in my_cls.__dict__.items():
10503        if name.startswith("test_"):
10504            # You cannot copy functions in Python, so we use closures here to
10505            # create objects with different ids. Otherwise, unittest.skip
10506            # would modify all methods sharing the same object id. Also, by
10507            # using a default argument, we create a copy instead of a
10508            # reference. Otherwise, we would lose access to the value.
10509
10510            @functools.wraps(value)
10511            def new_test(self, value=value):
10512                return value(self)
10513
10514            # Copy __dict__ which may contain test metadata
10515            new_test.__dict__ = copy.deepcopy(value.__dict__)
10516
10517            if xfail_prop is not None and hasattr(value, xfail_prop):
10518                new_test = unittest.expectedFailure(new_test)
10519
10520            tf = test_failures and test_failures.get(name)
10521            if tf is not None and suffix in tf.suffixes:
10522                skip_func = (
10523                    unittest.skip("Skipped!")
10524                    if tf.is_skip
10525                    else unittest.expectedFailure
10526                )
10527                new_test = skip_func(new_test)
10528
10529            setattr(other_cls, f"{name}_{suffix}", new_test)
10530
10531
10532if HAS_CPU:
10533
10534    class SweepInputsCpuTest(SweepInputs2, TestCase):
10535        gen = InputGen(10, "cpu")
10536
10537    SweepInputsCpuTest.populate()
10538
10539    class CpuTests(TestCase):
10540        common = check_model
10541        device = "cpu"
10542
10543    copy_tests(CommonTemplate, CpuTests, "cpu")
10544
10545if HAS_GPU and not TEST_WITH_ASAN:
10546
10547    class SweepInputsGPUTest(SweepInputs2, TestCase):
10548        gen = InputGen(10, GPU_TYPE)
10549
10550    SweepInputsGPUTest.populate()
10551
10552    class GPUTests(TestCase):
10553        common = check_model_gpu
10554        device = GPU_TYPE
10555
10556    copy_tests(CommonTemplate, GPUTests, GPU_TYPE)
10557
10558    class TritonCodeGenTests(TestCase):
10559        from torch._inductor.runtime.triton_heuristics import CachingAutotuner
10560
10561        device_type = GPU_TYPE
10562
10563        class NoOpCompilerBackend:
10564            def __init__(self):
10565                self.example_args = None
10566                self.model = None
10567
10568            def noop_backend(
10569                self,
10570                model_: torch.fx.GraphModule,
10571                example_inputs_: typing.List[torch.Tensor],
10572            ):
10573                """
10574                The Noop backend does not compile the fx graph it is given.
10575                Instead, it transforms the fx graph so that its functions are
10576                aten operations. It then saves this graph.
10577                """
10578                from torch._inductor.decomposition import select_decomp_table
10579                from torch._subclasses import FakeTensorMode
10580                from torch.fx import Interpreter
10581
10582                fake_mode = FakeTensorMode()
10583
10584                def interpret(*args, **kwargs):
10585                    return Interpreter(model_).run(*args[0:], **kwargs)
10586
10587                fake_flat_tensor_args = [
10588                    fake_mode.from_tensor(x) for x in example_inputs_
10589                ]
10590                fw_module = make_fx(interpret, select_decomp_table())(
10591                    *fake_flat_tensor_args
10592                )
10593                self.model = fw_module
10594                self.example_args = fake_flat_tensor_args
10595                return lambda x: example_inputs_
10596
10597        def get_kernels(self, fn, args) -> typing.List[CachingAutotuner]:
10598            from torch._inductor.debug import DebugContext
10599            from torch._inductor.graph import GraphLowering
10600            from torch._inductor.virtualized import V
10601
10602            cxt = TritonCodeGenTests.NoOpCompilerBackend()
10603            torch._dynamo.optimize(backend=cxt.noop_backend)(fn)(*args)
10604            graph = GraphLowering(cxt.model)
10605            kernels = []
10606            with V.set_graph_handler(graph), V.set_debug_handler(DebugContext()):
10607                graph.run(*(cxt.example_args))
10608                mod = graph.compile_to_module()
10609
10610                for val in mod.__dict__.values():
10611                    if isinstance(
10612                        val, torch._inductor.runtime.triton_heuristics.CachingAutotuner
10613                    ):
10614                        kernels.append(val)
10615
10616            return kernels
10617
10618        def test_divisible_by_16_covers_numel_args(self):
10619            torch._dynamo.reset()
10620
10621            def fn(a: torch.Tensor) -> torch.Tensor:
10622                return torch.sum(a)
10623
10624            kernels = self.get_kernels(fn, [torch.randn([256, 256], device=GPU_TYPE)])
10625            if config.triton.multi_kernel:
10626                self.assertTrue(
10627                    len(kernels) == 4,
10628                    "SUM should result in four kernels when multi-kernel is enabled",
10629                )
10630            else:
10631                self.assertTrue(len(kernels) == 2, "SUM should result in two kernels")
10632
10633            # kernel0 reduces from 256 to (xnumel=8, rnumel=8192), which means it reduces 256 by 256 into an array of
10634            # size 8 by accumulating 8192 elements at once note that rnumel is equal to 512 * 16, so rnumel which is
10635            # at slot 3 should be in the divisible by 16 descriptor
10636            arguments_that_are_divisible_by_16_in_kernel0 = (
10637                kernels[0].triton_meta["configs"][0].divisible_by_16
10638            )
10639            self.assertEqual(arguments_that_are_divisible_by_16_in_kernel0, (0, 1, 3))
10640
10641            # kernel1 reduces from 8 elements to a single scalar.
10642            # Since multi-kernel generate 2 variants for each kernel. The second
10643            # persistent-reduction has index 2.
10644            kernel1_index = 2 if config.triton.multi_kernel else 1
10645            arguments_that_are_divisible_by_16_in_kernel1 = (
10646                kernels[kernel1_index].triton_meta["configs"][0].divisible_by_16
10647            )
10648            self.assertEqual(arguments_that_are_divisible_by_16_in_kernel1, (0, 1))
10649            torch._dynamo.reset()
10650
10651        @config.patch(assume_aligned_inputs=False)
10652        def test_codegen_config_option_dont_assume_alignment(self):
10653            def fn(x: torch.Tensor) -> torch.Tensor:
10654                return x.sin() + x.cos()
10655
10656            # We want code that assumes alignment if the initial input is 16-byte aligned
10657            for offset in (0, 1, 2, 3, 4):
10658                base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE)
10659                inps = torch.as_strided(base, (64, 64), (64, 1), offset)
10660                torch._dynamo.reset()
10661                kernels = self.get_kernels(fn, [inps])
10662                arguments_that_are_divisible_by_16 = (
10663                    kernels[0].triton_meta["configs"][0].divisible_by_16
10664                )
10665
10666                #             NO_ALIGN ALIGN     ALIGN
10667                # def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr)
10668
10669                if offset % 4 == 0:
10670                    expected_aligned = (0, 1, 2)
10671                else:
10672                    expected_aligned = (1, 2)
10673                self.assertEqual(arguments_that_are_divisible_by_16, expected_aligned)
10674
10675            # If input isn't a view, storage offset != , inductor will assume alignment.
10676            torch._dynamo.reset()
10677            inp = torch.randn((64, 64), device=GPU_TYPE)
10678            kernels = self.get_kernels(fn, [inp])
10679            arguments_that_are_divisible_by_16 = (
10680                kernels[0].triton_meta["configs"][0].divisible_by_16
10681            )
10682            self.assertEqual(arguments_that_are_divisible_by_16, (0, 1, 2))
10683
10684        def test_optimize_indexing_dtype(self):
10685            def fn(x: torch.Tensor) -> torch.Tensor:
10686                return aten.upsample_bilinear2d.vec(x, None, True, [2.0, 2.0])
10687
10688            fn_opt = torch._dynamo.optimize("inductor")(fn)
10689            inps = [torch.randn(2, 4, 16, 16, device=GPU_TYPE)]
10690            code = run_and_get_triton_code(fn_opt, *inps)
10691            self.assertTrue("to(tl.int32)" in code)
10692            self.assertFalse("to(tl.int64)" in code)
10693
10694            self.assertEqual(fn_opt(*inps), fn(*inps))
10695
10696        def test_optimize_indexing_dtype_with_constraint(self):
10697            def fn1(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
10698                x = torch.arange(0, b.shape[0], device=GPU_TYPE)
10699                y = ((x + x) / 3).int()
10700                return a[y.to(torch.int64)]
10701
10702            def fn2(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
10703                torch._check_is_size(b.shape[0])
10704                torch._check(b.shape[0] >= 2)
10705                torch._check(b.shape[0] <= 100)
10706                return fn1(a, b)
10707
10708            fn1_opt = torch._dynamo.optimize("inductor")(fn1)
10709            fn2_opt = torch._dynamo.optimize("inductor")(fn2)
10710
10711            a = torch.rand([100, 100], device=GPU_TYPE)
10712            b = torch.rand([100], device=GPU_TYPE)
10713            torch._dynamo.mark_dynamic(b, 0)
10714            inps = [a, b]
10715
10716            code1 = run_and_get_triton_code(fn1_opt, *inps)
10717            code2 = run_and_get_triton_code(fn2_opt, *inps)
10718
10719            # The function with the constrained tensor should be optimized, but
10720            # the other should not:
10721            self.assertTrue("to(tl.int64)" in code1)
10722            self.assertTrue("to(tl.int32)" in code2)
10723            self.assertFalse("to(tl.int64)" in code2)
10724
10725            self.assertEqual(fn1_opt(*inps), fn1(*inps))
10726            self.assertEqual(fn2_opt(*inps), fn1(*inps))
10727
10728        def test_constant_folding_deallocation(self):
10729            import torch._inductor
10730
10731            def fn():
10732                li = []
10733                for i in range(10):
10734                    x = torch.full([100], i)
10735                    x = x + 1
10736                    li.append(x)
10737
10738                return li
10739
10740            mod = make_fx(fn)()
10741
10742            live_tensors = WeakTensorKeyDictionary()
10743            max_live_tensors = 0
10744
10745            class LiveTensors(TorchDispatchMode):
10746                def __torch_dispatch__(self, func, types, args=(), kwargs=None):
10747                    nonlocal live_tensors
10748                    nonlocal max_live_tensors
10749
10750                    kwargs = kwargs if kwargs else {}
10751                    for arg in pytree.arg_tree_leaves(*args, **kwargs):
10752                        if isinstance(arg, torch.Tensor):
10753                            live_tensors[arg] = True
10754
10755                    out = func(*args, **kwargs)
10756                    if not isinstance(out, torch.Tensor):
10757                        return out
10758
10759                    live_tensors[out] = True
10760                    max_live_tensors = max(max_live_tensors, len(live_tensors))
10761                    return out
10762
10763            mode = LiveTensors()
10764            from torch._inductor.fx_passes.joint_graph import UniformValueConstantFolder
10765
10766            with mode:
10767                UniformValueConstantFolder(mod).run()
10768
10769            # there are a couple extra tensors created in `insertable_tensor_check`
10770            self.assertTrue(max_live_tensors == 4)
10771
10772        # See https://github.com/pytorch/pytorch/issues/100348
10773        def test_inductor_detach_view(self):
10774            def fn(x: torch.Tensor) -> torch.Tensor:
10775                a = x * 2
10776                return a, a.detach()
10777
10778            fn_opt = torch._dynamo.optimize("inductor")(fn)
10779            inp = torch.ones(2, 2, requires_grad=True, device=GPU_TYPE)
10780            inp_ref = inp.clone().detach().requires_grad_(True)
10781            out_ref = fn(inp_ref)
10782            out = fn_opt(inp)
10783            out_ref[0].sum().backward()
10784            out[0].sum().backward()
10785            self.assertEqual(inp.grad, inp_ref.grad)
10786
10787        @skipIfRocm  # asserts not implemented in Rocm yet
10788        def test_optimize_indexing_assert(self):
10789            def has_indirect(code, tl_fn: str):
10790                self.assertTrue(
10791                    tl_fn in code,
10792                    msg=f"{tl_fn} not present:\n{code}",
10793                )
10794                for line in code.split("\n"):
10795                    if tl_fn in line:
10796                        stmt = line.split(tl_fn)[-1]
10797                        # indirect indexing involves a `tmp` variable
10798                        self.assertTrue(
10799                            "tmp" in stmt,
10800                            msg=f"Indirect indexing not present in code:\n{line}",
10801                        )
10802
10803            def has_assert(code, lower: bool, upper: bool):
10804                self.assertIn(
10805                    "device_assert", code, msg=f"No device asert found:\n{code}"
10806                )
10807                for line in code.split("\n"):
10808                    if "device_assert" in line:
10809                        self.assertTrue(
10810                            ("0 <= " in line) is lower,
10811                            msg=f"Lower bound {'' if lower else 'not '}elided:{line}",
10812                        )
10813                        self.assertTrue(
10814                            (" < " in line) is upper,
10815                            msg=f"Upper bound {'' if upper else 'not '}elided:{line}",
10816                        )
10817
10818            def fn(x: torch.Tensor) -> torch.Tensor:
10819                s = 1.0 * torch.arange(x.shape[0], device=x.device)
10820                return x[s.long()]
10821
10822            # aten.index
10823            for dynamic in (False, True):
10824                fn_opt = torch.compile(fn, dynamic=dynamic)
10825
10826                x = torch.randn(8, device=GPU_TYPE)
10827                code = run_and_get_triton_code(fn_opt, x)
10828                self.assertEqual(fn_opt(x), fn(x), msg=f"{dynamic=}")
10829
10830                # Check that there's indirect indexing...
10831                has_indirect(code, tl_fn="tl.load")
10832                if not dynamic:
10833                    # We elide the assert for static shapes
10834                    self.assertNotIn("device_assert", code)
10835                else:
10836                    # ...but we generate an upper bound for dynamic shapes
10837                    has_assert(code, lower=False, upper=True)
10838
10839            def fn(a, z, b, idx0, idx1):
10840                idx2 = torch.arange(a.shape[-1], device=a.device)
10841                a.index_put_((z, idx0, idx1, idx2), b, accumulate=True)
10842                return a
10843
10844            # aten.index_put
10845            for dynamic in (False, True):
10846                fn_opt = torch.compile(fn, dynamic=dynamic)
10847                a = torch.randn(1, 32, 32, 4, device=GPU_TYPE)
10848                z = torch.zeros((), dtype=torch.int64, device=GPU_TYPE)
10849                b = torch.randn(33, 1, device=GPU_TYPE)
10850                idx0 = torch.randint(32, (33,), device=GPU_TYPE).view(33, 1, 1)
10851                idx1 = torch.randint(32, (33,), device=GPU_TYPE).view(33, 1)
10852                inps = (a.clone(), z, b, idx0, idx1)
10853                code = run_and_get_triton_code(fn_opt, *inps)
10854
10855                # Correctness
10856                out_opt = fn_opt(a.clone(), z, b, idx0, idx1)
10857                out = fn(a.clone(), z, b, idx0, idx1)
10858                self.assertEqual(out_opt, out, msg=f"{dynamic=}")
10859
10860                # We have an indirect store via atomic_add
10861                has_indirect(code, tl_fn="tl.atomic_add")
10862                # We cannot elide he assert in this case
10863                has_assert(code, lower=True, upper=True)
10864
10865        def test_not_materialize_pointwise_reduction(self):
10866            def fn(a, b):
10867                return (a - b).sum(dim=-1).amax(dim=-1)
10868
10869            N = 16
10870            K = 7
10871            fn_opt = torch._dynamo.optimize("inductor")(fn)
10872            inps = [
10873                torch.randn(N, 1, K, device=GPU_TYPE),
10874                torch.randn(1, N, K, device=GPU_TYPE),
10875            ]
10876            code = run_and_get_triton_code(fn_opt, *inps)
10877            self.assertEqual(
10878                code.count("tl.store"), 2 if config.triton.multi_kernel else 1
10879            )
10880            self.assertTrue("out_ptr1" in code)
10881            self.assertFalse("out_ptr0" in code)
10882            self.assertEqual(fn_opt(*inps), fn(*inps))
10883
10884        def test_numpy_on_gpu(self):
10885            x = np.arange(10, dtype=np.float32)
10886
10887            @torch.compile
10888            def fn(x):
10889                return np.sin(x)
10890
10891            def fn_gpu(x):
10892                with torch.device(GPU_TYPE):
10893                    return fn(x)
10894
10895            r = fn_gpu(x)
10896            code = run_and_get_triton_code(fn_gpu, x)
10897            self.assertIn("tl_math.sin", code)
10898            self.assertEqual(type(r), np.ndarray)
10899            self.assertEqual(r, np.sin(x))
10900
10901        def test_numpy_autograd(self):
10902            def my_torch(x):
10903                y = torch.cat([torch.sin(x) ** 2, torch.max(x)[None]])
10904                return y.sum()
10905
10906            def my_np(x):
10907                y = np.concatenate([np.sin(x) ** 2, np.max(x)[None]])
10908                return np.sum(y)
10909
10910            @torch.compile
10911            def wrapper(x):
10912                return torch.compiler.wrap_numpy(my_np)(x)
10913
10914            @torch.compile
10915            def wrapper2(x):
10916                x = x.numpy()
10917                y = my_np(x)
10918                return torch.from_numpy(y)
10919
10920            x_np = torch.arange(8, dtype=torch.float32, requires_grad=True)
10921            x = torch.arange(8, dtype=torch.float32, requires_grad=True)
10922            out_np = wrapper(x_np)
10923            out = my_torch(x)
10924            self.assertEqual(out, out_np)
10925
10926            x2_np = torch.arange(8, dtype=torch.float32, requires_grad=True)
10927            out2_np = wrapper2(x2_np)
10928            self.assertEqual(out, out2_np)
10929
10930            out_np.backward()
10931            out.backward()
10932            self.assertEqual(x.grad, x_np.grad)
10933
10934            out2_np.backward()
10935            self.assertEqual(x.grad, x2_np.grad)
10936
10937        # Disable constant propagation, so we isolate value range analysis
10938        @patch.object(config, "constant_and_index_propagation", False)
10939        @patch.object(config, "joint_graph_constant_folding", False)
10940        def test_cant_optimize_compute(self):
10941            def ones():
10942                return torch.ones([4], device=GPU_TYPE)
10943
10944            def suffix(inp):
10945                return (inp.to(torch.int64) + 1).to(torch.float64)
10946
10947            ten = torch.rand([4], device=GPU_TYPE)
10948
10949            for foo in (
10950                lambda x: x + 2147483657,
10951                lambda x: torch.where(x < 0, ones(), ones() - 2) * (-(2 ** (40))),
10952                lambda x: x + ten,
10953                lambda x: x + ten.sum(),
10954            ):
10955
10956                def fn():
10957                    return suffix(foo(ones()))
10958
10959                fn_opt = torch._dynamo.optimize("inductor")(fn)
10960                code = run_and_get_triton_code(fn_opt)
10961
10962                # this cannot be optimized away, value too large
10963                self.assertTrue("to(tl.int64)" in code)
10964                self.assertEqual(fn_opt(), fn())
10965
10966        # Disable constant propagation, so we isolate value range analysis
10967        @patch.object(config, "constant_and_index_propagation", False)
10968        @patch.object(config, "joint_graph_constant_folding", False)
10969        def test_optimize_compute(self):
10970            def ones():
10971                return torch.ones([4], device=GPU_TYPE)
10972
10973            def suffix(inp):
10974                return (inp.to(torch.int64) + 1).to(torch.float64)
10975
10976            for foo in (
10977                lambda x: x + 500,
10978                lambda x: torch.where(x < 0, ones(), ones() - 2) * (-(2 ** (20))),
10979                lambda x: x / 30,
10980            ):
10981
10982                def fn():
10983                    return suffix(foo(ones()))
10984
10985                fn_opt = torch._dynamo.optimize("inductor")(fn)
10986                code = run_and_get_triton_code(fn_opt)
10987
10988                # this can be optimized away, value too large
10989                self.assertTrue("to(tl.int64)" not in code)
10990                self.assertTrue("to(tl.int32)" in code)
10991
10992                self.assertEqual(fn_opt(), fn())
10993
10994        @config.patch("triton.use_block_ptr", False)
10995        def test_evict_last_non_coalesced_loads(self):
10996            @torch.compile
10997            def f(a, b):
10998                return (a * b).sum(dim=-1)
10999
11000            N = 512
11001            inps = (
11002                torch.randn(N, N, N, device=GPU_TYPE).permute(2, 1, 0),
11003                torch.randn(N, N, N, device=GPU_TYPE).permute(1, 2, 0),
11004            )
11005            code = run_and_get_triton_code(f, *inps)
11006            lines = [line for line in code.split("\n") if "tl.load" in line]
11007            if config.triton.multi_kernel:
11008                # the first 2 lines are generated for the persistent reduction
11009                # variant.
11010                self.assertExpectedInline(
11011                    "\n".join(lines),
11012                    """\
11013    tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
11014    tmp1 = tl.load(in_ptr1 + (x3 + (262144*r2)), rmask, other=0.0)
11015        tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
11016        tmp1 = tl.load(in_ptr1 + (x3 + (262144*r2)), rmask, eviction_policy='evict_first', other=0.0)""",
11017                )
11018            else:
11019                self.assertExpectedInline(
11020                    "\n".join(lines),
11021                    """\
11022        tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
11023        tmp1 = tl.load(in_ptr1 + (x3 + (262144*r2)), rmask, eviction_policy='evict_first', other=0.0)""",
11024                )
11025
11026        @skipIfRocm
11027        @config.patch("triton.use_block_ptr", True)
11028        def test_evict_last_non_coalesced_loads_block_ptr(self):
11029            @torch.compile
11030            def f(a, b):
11031                return (a * b).sum(dim=-1)
11032
11033            N = 512
11034            inps = (
11035                torch.randn(N, N, N, device=GPU_TYPE).permute(2, 1, 0),
11036                torch.randn(N, N, N, device=GPU_TYPE).permute(1, 2, 0),
11037            )
11038            code = run_and_get_triton_code(f, *inps)
11039            lines = [line for line in code.split("\n") if "tl.load" in line]
11040
11041            if config.triton.multi_kernel:
11042                # the first 2 lines are generated for the persistent reduction
11043                # variant.
11044                self.assertExpectedInline(
11045                    "\n".join(lines),
11046                    """\
11047    tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
11048    tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[262144, 512], strides=[1, 262144], block_shape=[XBLOCK, RBLOCK], order=[0, 1], offsets=[xoffset, roffset]), boundary_check=[1], padding_option='zero')
11049        tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
11050        tmp1 = tl.load(block_ptr0, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""",  # noqa: B950 line too long
11051                )
11052            else:
11053                self.assertExpectedInline(
11054                    "\n".join(lines),
11055                    """\
11056        tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
11057        tmp1 = tl.load(block_ptr0, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""",
11058                )
11059
11060        # Disable index propagation, so the indirect indexing isn't optimized away
11061        @patch.object(config, "constant_and_index_propagation", False)
11062        def test_computed_indirect_mask(self):
11063            def fn(x, n):
11064                tmp = torch.arange(n, device=x.device)
11065                return x[tmp] + 1
11066
11067            x = torch.randn(8, device=GPU_TYPE)
11068            fn_opt = torch.compile(fn)
11069            code = run_and_get_triton_code(fn_opt, x, 8)
11070            # load should be masked
11071            self.assertTrue("tl.load(in_ptr0 + (tmp0), xmask" in code)
11072            self.assertEqual(fn(x, 8), fn_opt(x, 8))
11073
11074        def test_kernel_names_descriptive(self):
11075            @torch._dynamo.optimize("inductor")
11076            def fn1(x):
11077                return x.cos().sin()
11078
11079            @torch._dynamo.optimize("inductor")
11080            def fn2(x):
11081                x = torch.mm(x, x)
11082                x = torch.softmax(x, dim=1)
11083                return x
11084
11085            mod = nn.Sequential(
11086                nn.Linear(4, 4),
11087                nn.LayerNorm(4),
11088                nn.ReLU(),
11089            ).to(device=GPU_TYPE)
11090
11091            @torch._dynamo.optimize("inductor")
11092            def fn3(x):
11093                return mod(x)
11094
11095            func_and_kernel_aten = [
11096                (fn1, "triton_poi_fused_cos_sin", (torch.randn(8, device=GPU_TYPE),)),
11097                (
11098                    fn2,
11099                    "triton_poi_fused__softmax",
11100                    (torch.randn(4, 4, device=GPU_TYPE),),
11101                ),
11102                (
11103                    fn3,
11104                    "triton_poi_fused_native_layer_norm_relu",
11105                    (torch.randn(4, 4, device=GPU_TYPE),),
11106                ),
11107            ]
11108            func_and_kernel_torch = [
11109                (fn1, "triton_poi_fused_cos_sin", (torch.randn(8, device=GPU_TYPE),)),
11110                (
11111                    fn2,
11112                    "triton_poi_fused_softmax",
11113                    (torch.randn(4, 4, device=GPU_TYPE),),
11114                ),
11115                (
11116                    fn3,
11117                    "triton_poi_fused_layer_norm_relu"
11118                    if torch._dynamo.config.inline_inbuilt_nn_modules
11119                    else "triton_poi_fused_LayerNorm_ReLU",
11120                    (torch.randn(4, 4, device=GPU_TYPE),),
11121                ),
11122            ]
11123
11124            def test_funcs(func_and_kernel):
11125                with torch.no_grad():
11126                    for fn, kernel_name, inps in func_and_kernel:
11127                        code = run_and_get_triton_code(fn, *inps)
11128                        if kernel_name not in code:
11129                            print(code)
11130                        self.assertTrue(kernel_name in code)
11131
11132            test_funcs(func_and_kernel_aten)
11133            patch.object(config.triton, "descriptive_names", "torch")(test_funcs)(
11134                func_and_kernel_torch
11135            )
11136
11137        @patch.object(config, "profile_bandwidth", True)
11138        def test_bandwidth_profiler(self):
11139            @torch._dynamo.optimize("inductor")
11140            def fn(x):
11141                x = x.cos()
11142                x = x.cos()
11143                x = torch.mm(x, x)
11144                x = x.sin()
11145                x = x.relu()
11146                return x
11147
11148            inp = torch.randn(4, 4, device=GPU_TYPE)
11149            code = run_and_get_triton_code(fn, inp)
11150            fn(inp)
11151            self.assertTrue("start_graph" in code)
11152            self.assertTrue("end_graph" in code)
11153
11154        def test_split_op_with_sym(self):
11155            def fn(x: torch.Tensor) -> torch.Tensor:
11156                # split(tensor, sympy.Integer), split(tensor, sympy.Expr)
11157                return torch.split(x, x.shape[0]), torch.split(x, x.shape[0] // 2)
11158
11159            for dynamic_shapes in [True, False]:
11160                with torch._dynamo.config.patch(dynamic_shapes=dynamic_shapes):
11161                    torch._dynamo.reset()
11162                    fn_opt = torch._dynamo.optimize("inductor", dynamic=dynamic_shapes)(
11163                        fn
11164                    )
11165                    inps = torch.randn([5, 5])
11166                    fn_opt(inps)
11167
11168        @skipIfRocm
11169        @unittest.skipIf(IS_FBCODE, "fbcode system python does not provide torch")
11170        def test_indirect_device_assert(self):
11171            dir_path = os.path.dirname(os.path.realpath(__file__))
11172            test_path = os.path.join(dir_path, "indirect_assert_helper.py")
11173            fns = ("first_arg", "store", "second_arg", "same_pm_one", "same_pp_one")
11174
11175            def test(fn, ndims, dyn_shape, one_size=False):
11176                proc = subprocess.Popen(
11177                    [
11178                        sys.executable,
11179                        test_path,
11180                        fn,
11181                        str(ndims),
11182                        str(dyn_shape),
11183                        str(one_size),
11184                    ],
11185                    stdout=subprocess.PIPE,
11186                    stderr=subprocess.PIPE,
11187                    env={**os.environ, "MKL_THREADING_LAYER": "GNU"},
11188                )
11189                stderr = proc.communicate()[1]
11190                self.assertTrue(
11191                    any(
11192                        "out of bounds" in err.decode("utf-8")
11193                        for err in stderr.splitlines()
11194                    ),
11195                    f"{fn}, {ndims}, {dyn_shape}, {one_size}",
11196                )
11197
11198            for fn, ndims, dyn_shape in itertools.product(fns, (2, 3), (True, False)):
11199                test(fn, ndims, dyn_shape)
11200
11201            test("first_arg", 2, False, True)
11202
11203            for fn, dyn_shape in itertools.product(
11204                ("upper1", "upper2", "lower1", "lower2"), (True, False)
11205            ):
11206                test(fn, 2, dyn_shape)
11207
11208        @patch("torch._inductor.config.comment_origin", True)
11209        @patch("torch._functorch.config.max_dist_from_bw", 0)
11210        def test_inductor_sequence_nr(self):
11211            class Model(torch.nn.Module):
11212                def __init__(self):
11213                    super().__init__()
11214                    self.conv1 = torch.nn.Conv2d(
11215                        in_channels=16,
11216                        out_channels=16,
11217                        kernel_size=(1, 1),
11218                        stride=1,
11219                        padding="same",
11220                        bias=True,
11221                    )
11222                    self.bn1 = torch.nn.BatchNorm2d(num_features=16)
11223                    self.relu1 = torch.nn.ReLU()
11224                    self.loss_fn = torch.nn.L1Loss()
11225
11226                def forward(self, x, target):
11227                    y = x
11228                    x = self.conv1(x)
11229                    x = self.bn1(x)
11230                    x = self.relu1(x)
11231                    x = x + y
11232                    x = torch.flatten(x)
11233                    output = self.loss_fn(x, target)
11234                    return (output,)
11235
11236            def get_triton_codegen(optimized_module, args):
11237                def run_with_backward():
11238                    result = optimized_module(*args)
11239                    result[0].backward()
11240                    return result
11241
11242                res, (fwd_code, bwd_code) = run_and_get_code(run_with_backward)
11243                return fwd_code, bwd_code
11244
11245            x = torch.rand(100, 16, 32, 32, requires_grad=True, device=GPU_TYPE)
11246            target = torch.rand(1, device=GPU_TYPE)
11247            args = [x, target]
11248            model = Model().to(device=GPU_TYPE)
11249            opt_model = torch.compile(model)
11250            fwd_code, bwd_code = get_triton_codegen(opt_model, args)
11251
11252            bwd_seq_nr_set = set()
11253            fwd_seq_nr_set = set()
11254            for idx, code in enumerate([fwd_code, bwd_code]):
11255                seq_nr_set = bwd_seq_nr_set if idx > 0 else fwd_seq_nr_set
11256                prefix = "BWD" if idx > 0 else "FWD"
11257                for line in code.split("\n"):
11258                    if "seq_nr" in line:
11259                        res = re.search(r"seq_nr:(\d+)", line)
11260                        if res:
11261                            seq_nr_set.add(int(res.group(1)))
11262            self.assertTrue(bwd_seq_nr_set.issubset(fwd_seq_nr_set))
11263
11264        @config.patch(
11265            {
11266                "coordinate_descent_tuning": True,
11267                "triton.unique_kernel_names": True,
11268                "benchmark_kernel": True,
11269            }
11270        )
11271        @skipIfRocm
11272        @expectedFailureXPU
11273        @unittest.skipIf(
11274            torch.cuda.is_available() and torch.cuda.get_device_capability() < (9, 0),
11275            "Triton does not support fp8 on A100",
11276        )
11277        def test_red_followed_by_transposed_pointwise(self):
11278            bs = 26624
11279            dim = 1024
11280
11281            @torch.compile(dynamic=False)
11282            def f(in1, in2, a, b):
11283                out = torch.nn.functional.silu(in1) * in2
11284                out_row = (out / out.amax(dim=1, keepdim=True)).to(torch.float8_e4m3fn)
11285                out_col = (out / out.amax(dim=0, keepdim=True)).to(torch.float8_e4m3fn)
11286
11287                # setup strides for _scaled_mm
11288                out_row = out_row.contiguous()
11289                out_col = out_col.t().contiguous().t()
11290
11291                return (
11292                    torch._scaled_mm(out_row, a, out_dtype=torch.bfloat16)[0],
11293                    torch._scaled_mm(b, out_col, out_dtype=torch.bfloat16)[0],
11294                )
11295
11296            in1 = torch.randn((bs, dim), dtype=torch.bfloat16, device=GPU_TYPE)
11297            in2 = torch.randn((bs, dim), dtype=torch.bfloat16, device=GPU_TYPE)
11298            a = (
11299                torch.randn((dim, dim), dtype=torch.bfloat16, device=GPU_TYPE)
11300                .t()
11301                .to(torch.float8_e4m3fn)
11302            )
11303            b = torch.randn((dim, bs), dtype=torch.bfloat16, device=GPU_TYPE).to(
11304                torch.float8_e4m3fn
11305            )
11306
11307            # warmup
11308            _, (wrapper,) = run_and_get_code(f, in1, in2, a, b)
11309
11310            # Previously indcutor decide reduction hint for a reduction kernel without considering
11311            # the pointwise nodes. That will cause the third reduction kernel in this wrapper to be a
11312            # persistent inner reduction and cause bad perf.
11313            #
11314            # We fix that by making the third reduction a non-persistent reduction
11315            # and improve the perf by 4.14x (451us -> 109us)
11316            self.assertEqual(3, wrapper.count("def triton_red_"))
11317            self.assertEqual(0, wrapper.count("def triton_per_"))
11318
11319            if DO_PERF_TEST:
11320                with torch.profiler.profile(
11321                    activities=[torch.profiler.ProfilerActivity.CUDA]
11322                ) as p:
11323                    for _ in range(1000):
11324                        f(in1, in2, a, b)
11325
11326                print(p.key_averages().table(max_name_column_width=200))
11327
11328    class RNNTest(TestCase):
11329        device_type = GPU_TYPE
11330
11331        class Model(torch.nn.Module):
11332            def __init__(self):
11333                super().__init__()
11334                self.gru = torch.nn.GRU(16, 16, batch_first=True)
11335
11336            def forward(self, x):
11337                return self.gru(x)
11338
11339        @expectedFailureXPU
11340        def test_rnn_compile_safe(self):
11341            device = torch.device(GPU_TYPE)
11342            model = RNNTest.Model().to(device)
11343            model = torch._dynamo.optimize("inductor")(model)
11344            x = torch.rand(1024, 20, 16).to(device)
11345            model(x)
11346
11347    class NanCheckerTest(TestCase):
11348        @config.patch("nan_asserts", True)
11349        def test_nan_checker_pass(self):
11350            def f(x):
11351                return torch.softmax(x, dim=-1)
11352
11353            x = torch.randn(2, 1024, device=GPU_TYPE)
11354            ref = f(x)
11355            actual, (code,) = run_and_get_code(torch.compile(f), x)
11356            self.assertTrue(torch.allclose(ref, actual))
11357            self.assertTrue("# make sure graph inputs are not nan/inf" in code)
11358            self.assertTrue(
11359                re.search(r"assert not .*\.isnan\(\)\.any\(\).item\(\)", code)
11360                is not None
11361            )
11362            self.assertTrue(
11363                re.search(r"assert not .*\.isinf\(\)\.any\(\).item\(\)", code)
11364                is not None
11365            )
11366
11367        @config.patch("nan_asserts", True)
11368        def test_nan_checker_fail(self):
11369            def f(x):
11370                return torch.softmax(x, dim=-1)
11371
11372            x = torch.randn(2, 1024, device=GPU_TYPE)
11373            x[0, 0] = float("nan")
11374            with self.assertRaises(AssertionError):
11375                torch.compile(f)(x)
11376
11377
11378if HAS_CPU:
11379
11380    class TestFull(TestCase):
11381        def test_full_dtype(self):
11382            pytypes = (
11383                bool,
11384                int,
11385                float,
11386                # TODO: Triton's JITFunction._type_of has no support for complex
11387                # complex,
11388            )
11389
11390            dtypes = (
11391                torch.bool,
11392                torch.int32,
11393                torch.int64,
11394                torch.float32,
11395                torch.float64,
11396                None,
11397                # torch.complex64,
11398                # torch.complex128,
11399            )
11400
11401            def fn(pytype, dtype):
11402                if pytype is bool:
11403                    fill_value = True
11404                elif pytype is int:
11405                    fill_value = 42
11406                elif pytype is float:
11407                    fill_value = 42.0
11408                else:
11409                    raise AssertionError(f"Unexpected Python type: {pytype}")
11410
11411                return torch.full(
11412                    (4, 6), fill_value, dtype=dtype, device=torch.device("cpu")
11413                )
11414
11415            fn_opt = torch._dynamo.optimize("inductor")(fn)
11416
11417            for pytype, dtype in itertools.product(pytypes, dtypes):
11418                with enable_python_dispatcher():
11419                    with torch.no_grad():
11420                        ret_opt = fn_opt(pytype, dtype)
11421
11422                self.assertEqual(ret_opt, fn(pytype, dtype))
11423
11424
11425if __name__ == "__main__":
11426    from torch._inductor.test_case import run_tests
11427
11428    if HAS_CPU or HAS_GPU:
11429        run_tests(needs="filelock")
11430